diff --git a/.github/copy-pr-bot.yaml b/.github/copy-pr-bot.yaml new file mode 100644 index 0000000000..895ba83ee5 --- /dev/null +++ b/.github/copy-pr-bot.yaml @@ -0,0 +1,4 @@ +# Configuration file for `copy-pr-bot` GitHub App +# https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/ + +enabled: true diff --git a/.github/ops-bot.yaml b/.github/ops-bot.yaml deleted file mode 100644 index 84bbe71f46..0000000000 --- a/.github/ops-bot.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# This file controls which features from the `ops-bot` repository below are enabled. -# - https://github.com/rapidsai/ops-bot - -copy_prs: true diff --git a/.github/workflows/cpu-horovod.yml b/.github/workflows/cpu-horovod.yml index 13c93d2588..69dd024858 100644 --- a/.github/workflows/cpu-horovod.yml +++ b/.github/workflows/cpu-horovod.yml @@ -72,4 +72,4 @@ jobs: if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then extra_pytest_markers="and changed" fi - EXTRA_PYTEST_MARKERS="$extra_pytest_markers" MERLIN_BRANCH="$merlin_branch" COMPARE_BRANCH=${{ github.base_ref }} tox -e horovod-cpu + PYTEST_MARKERS="$extra_pytest_markers" MERLIN_BRANCH="$merlin_branch" COMPARE_BRANCH=${{ github.base_ref }} tox -e horovod-cpu diff --git a/.github/workflows/gpu-multi.yml b/.github/workflows/gpu-multi.yml index 3d47558932..dd0cc5934b 100644 --- a/.github/workflows/gpu-multi.yml +++ b/.github/workflows/gpu-multi.yml @@ -56,4 +56,47 @@ jobs: if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then extra_pytest_markers="and changed" fi - cd ${{ github.workspace }}; EXTRA_PYTEST_MARKERS=$extra_pytest_markers MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e multi-gpu + cd ${{ github.workspace }}; PYTEST_MARKERS="multigpu $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu,horovod-gpu + + check-changes-torch: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install GitPython + pip install . 
--no-deps + - name: Get changed backends + id: backend_check + run: | + echo "changed=$(python ci/get_changed_backends.py --backend torch --branch ${{github.base_ref}})" >> "$GITHUB_OUTPUT" + outputs: + needs_testing: ${{ steps.backend_check.outputs.changed }} + + torch: + needs: check-changes-torch + if: ${{needs.check-changes-torch.outputs.needs_testing == 'true' || github.ref == 'refs/heads/main'}} + runs-on: 2GPU + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Run tests + run: | + ref_type=${{ github.ref_type }} + branch=main + if [[ $ref_type == "tag"* ]] + then + git -c protocol.version=2 fetch --no-tags --prune --progress --no-recurse-submodules --depth=1 origin +refs/heads/release*:refs/remotes/origin/release* + branch=$(git branch -r --contains ${{ github.ref_name }} --list '*release*' --format "%(refname:short)" | sed -e 's/^origin\///') + fi + if [[ "${{ github.ref }}" != 'refs/heads/main' ]]; then + extra_pytest_markers="and changed" + fi + cd ${{ github.workspace }}; PYTEST_MARKERS="multigpu $extra_pytest_markers" MERLIN_BRANCH=$branch COMPARE_BRANCH=${{ github.base_ref }} tox -e gpu diff --git a/.github/workflows/gpu.yml b/.github/workflows/gpu.yml index 58ff8c862e..20fade2647 100644 --- a/.github/workflows/gpu.yml +++ b/.github/workflows/gpu.yml @@ -143,4 +143,4 @@ jobs: merlin_branch="${{ steps.get-branch-name.outputs.branch }}" RAPIDS_VERSION=${{ matrix.version.rapids }} MERLIN_BRANCH=$merlin_branch COMPARE_BRANCH=$merlin_branch \ PYTEST_MARKERS="(examples or notebook) $extra_pytest_markers" \ - tox -e gpu-cu11 + tox -e gpu-cu11 \ No newline at end of file diff --git a/.gitignore b/.gitignore index e5f5cb5bcc..ed7bef7e91 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,6 @@ dmypy.json # Experiment files _test.py + +# Lightning +**/lightning_logs/ \ No newline at end of file diff --git a/docs/source/_static/NVIDIA-LogoBlack.svg b/docs/source/_static/NVIDIA-LogoBlack.svg new file mode 100755 index 0000000000..c612396c71 --- /dev/null +++ b/docs/source/_static/NVIDIA-LogoBlack.svg @@ -0,0 +1 @@ +NVIDIA-LogoBlack \ No newline at end of file diff --git a/docs/source/_static/NVIDIA-LogoWhite.svg b/docs/source/_static/NVIDIA-LogoWhite.svg new file mode 100755 index 0000000000..942ca3b2a0 --- /dev/null +++ b/docs/source/_static/NVIDIA-LogoWhite.svg @@ -0,0 +1,58 @@ + + + + + + + NVIDIA-LogoBlack + + + + + diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css index 319ddff89a..7287e49212 100644 --- a/docs/source/_static/css/custom.css +++ b/docs/source/_static/css/custom.css @@ -1,34 +1,472 @@ -.wy-nav-content { - margin: 0; - background: #fcfcfc; - padding-top: 40px; +/* +# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +/* Parts of this are adapted from the NVIDIA Omniverse Docs Sphinx Theme */ + +/* Set up for old browsers*/ +@supports not (font-variation-settings: normal) { + @font-face { + font-family: "NVIDIA"; + src: url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Lt.woff") format("woff"), + url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Lt.woff2") format("woff2"); + font-weight: 300; + font-style: normal; + } + @font-face { + font-family: "NVIDIA"; + src: url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Rg.woff") format("woff"), + url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Rg.woff2") format("woff2"); + font-weight: 400; + font-style: normal; + } + @font-face { + font-family: "NVIDIA"; + src: url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Md.woff") format("woff"), + url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Md.woff2") format("woff2"); + font-weight: 500; + font-style: normal; + } + @font-face { + font-family: "NVIDIA"; + src: url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Bd.woff") format("woff"), + url("https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Bd.woff2") format("woff2"); + font-weight: 700; + font-style: normal; + } +} + +/* Set up for modern browsers, all weights */ +@supports (font-variation-settings: normal) { + @font-face { + font-family: 'NVIDIA'; + src: url('https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/var/NVIDIASansVF_W_Wght.woff2') format('woff2 supports variations'), + url('https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/var/NVIDIASansVF_W_Wght.woff2') format('woff2-variations'); + font-weight: 100 1000; + font-stretch: 25% 151%; + font-style: normal; + } + @font-face{ + font-family:'NVIDIA'; + src:url('https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/var/NVIDIASansVF_Wght_W_Italic.woff2') format('woff2 supports variations'), + url('https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/var/NVIDIASansVF_Wght_W_Italic.woff2') format('woff2-variations'); + font-weight:100 1000; + font-stretch:25% 151%; + font-style:italic; + } } -.wy-side-nav-search { +:root +{ + /* nv branding */ + --nv-green: #76b900; + --nv-green-illuminate: #76d300; /* button state - hover */ + --nv-black: #000000; + 
--nv-white: #ffffff; + --nv-green-2: #004831; + + --nv-success: var(--nv-green); + --nv-error: #f44336; + + --nv-font-face: NVIDIA,Arial,Helvetica,Sans-Serif; + --nv-font-face-mono: SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace; + + /* nv branding: light theme */ + --text: #1a1a1a; + --background-default: #ffffff; + --background-alternate: #eeeeee; + --ui-and-graphics: #999999; + + --white: #ffffff; + --gray-1: #f7f7f7; + --gray-2: #eeeeee; + --gray-3: #dddddd; + --gray-4: #cccccc; + + /* nv branding: light theme mobile (closely matches our old font sizes) */ + --h1-color: var(--nv-green); + --h1-font-weight: 100; + --h1-letter-spacing: -0.02em; + --h1-font-size: 36px; + --h1-line-height: 1em; + --h1-text-transform: uppercase; + + --h2-color: var(--nv-green); + --h2-font-weight: 100; + --h2-letter-spacing: -0.02em; + --h2-font-size: 24px; + --h2-line-height: 1em; + --h2-text-transform: uppercase; + + --h3-color: var(--nv-green); + --h3-font-weight: 100; + --h3-letter-spacing: -0.02em; + --h3-font-size: 21px; + --h3-line-height: 1em; + --h3-text-transform: uppercase; + + --h4-color: var(--nv-green); + --h4-font-weight: 100; + --h4-letter-spacing: -0.02em; + --h4-font-size: 18px; + --h4-line-height: 1em; + --h4-text-transform: uppercase; + + --h5-color: var(--nv-green); + --h5-font-size: var(--body-font-size); + + --h6-color: var(--nv-green); + --h6-font-weight: 400; + + --body-font-color: var(--text); + --body-font-weight: normal; + --body-font-size: 16px; + --body-line-height: 1.5em; + + --small-font-color: var(--ui-and-graphics); + --small-font-weight: normal; + --small-font-size: 12px; + --small-line-height: 1.25em; + + --ul-font-color: var(--text); + --ul-font-weight: normal; + --ul-font-size: 16px; + --ul-line-height: 2em; + --ul-marker-font-face: FontAwesome; + --ul-marker-content: '\f105 \00a0 \00a0'; + + --ol-font-color: var(--text); + --ol-font-weight: normal; + --ol-font-size: 16px; + --ol-line-height: 2em; + --ol-list-style-type: decimal; + --ol-ol-list-style-type: upper-alpha; + --ol-ol-ol-list-style-type: decimal; /* not specified in style guide */ + + --disabled-font-color: var(--gray-4); + --disabled-font-weight: normal; + --disabled-font-size: 16px; + --disabled-line-height: 1em; /* style guide says 16px */ + + --error-font-color: var(--nv-error); + --error-font-weight: normal; + --error-font-size: 16px; + --error-line-height: 1em; /* style guide says 16px */ + + --success-font-color: var(--nv-success); + --success-font-weight: normal; + --success-font-size: 16px; + --success-line-height: 1em; /* style guide says 16px */ + + /* omni-style */ + --sidebar-color: #000000; + --sidebar-alt-color: #333333; + --sidebar-headline-color: var(--nv-green); + --sidebar-text-color: #cccccc; + + --table-background-header: var(--nv-black); + --table-background-alternate: var(--background-alternate); /* for alternating rows */ + --table-text: var(--text); + --table-border: var(--ui-and-graphics); + --table-border-header: var(--gray-3); + + /* this is off-brand, but `uppercase` makes headings with source code look bad. 
*/ + --h1-text-transform: none; + --h2-text-transform: none; + --h3-text-transform: none; + --h4-text-transform: none; + + --h3-font-weight: normal; /* this is off-brand and overrides the above definition */ + + --note-background-color: var(--nv-green); + --note-background-alt-color: #cccccc; + + --important-background-color: #f44336; + --important-background-alt-color: #cccccc; + + --link-color: var(--nv-green); + --link-visited-color: var(--nv-green); + --link-hover-color: var(--nv-green-illuminate); + + --background-color: var(--background-default); + + /* template T* tryAcquireInterface(const void* pluginInterface) */ + --api-member-header-background-color: var(--gray-2); + --api-member-header-border-color: var(--sidebar-headline-color); + --api-member-header-text-color: var(--text); + --api-member-header-link-color: var(--link-color); + + --api-member-background-color: var(--gray-1); + + /* struct carb::Framework */ + --api-header-text-color: var(--nv-green); + --api-header-border-color: var(--ui-and-graphics); + + /* sphinx-design color modifications */ + --sd-color-tabs-label-active: var(--nv-green); + --sd-color-tabs-underline-active: var(--nv-green); + + --sd-color-tabs-label-hover: var(--nv-green-illuminate); + --sd-color-tabs-underline-hover: var(--nv-green-illuminate); +} + +/* Custom Styles */ +:root { + --pst-font-size-base: none; + --pst-color-admonition-note: var(--pst-color-primary); + --pst-color-admonition-default: var(--pst-color-primary); + --pst-color-info: 255, 193, 7; + --pst-color-admonition-tip: var(--pst-color-info); + --pst-color-admonition-hint: var(--pst-color-info); + --pst-color-admonition-important: var(--pst-color-info); + --pst-color-warning: 245, 162, 82; + --pst-color-danger: 230, 101, 129; + --pst-color-admonition-warning: var(--pst-color-danger); + --pst-color-link: 118, 185, 0; + --pst-color-inline-code: 92, 22, 130; + --font-family-sans-serif: NVIDIA Sans, Helvetica, Arial, var(--pst-font-family-base-system); + --pst-font-family-heading: NVIDIA Sans, Helvetica, Arial, var(--pst-font-family-base-system); + --pst-font-family-monospace: Roboto Mono, var(--pst-font-family-monospace-system); + font-family: NVIDIA Sans, Helvetica, Arial,Sans-serif; +} + + +html[data-theme="light"] { + --pst-color-primary: var(--nv-green); +} +html[data-theme="dark"] { + --pst-color-primary: var(--nv-green); +} + +/**********************************************************************************************************************/ +/* Standard Text Formatting */ +/**********************************************************************************************************************/ + +/* Headline Formatting */ +.bd-container h1 +{ + color: var(--h1-color); + + font-weight: var(--h1-font-weight); + font-size: var(--h1-font-size); + font-style: normal; + + line-height: var(--h1-line-height); + margin-top: 0.75em; + margin-bottom: 0.75em !important; /* override RTD theme */ + + text-transform: var(--h1-text-transform); +} + +.bd-container h2 +{ + color: var(--h2-color); + + font-weight: var(--h2-font-weight); + font-size: var(--h2-font-size); + font-style: normal; + + line-height: var(--h2-line-height); + margin-top: 1.25em; + margin-bottom: 0.5em !important; /* override RTD theme */ + + text-transform: var(--h2-text-transform); +} + +.bd-container h3 +{ + color: var(--h3-color); + + font-weight: var(--h3-font-weight); + font-size: var(--h3-font-size); + font-style: normal; + + line-height: var(--h3-line-height); + margin-top: 1.25em; + margin-bottom: 0.5em !important; /* 
override RTD theme */ + + text-transform: var(--h3-text-transform); +} + +.bd-container h4 +{ + color: var(--h4-color); + + font-weight: var(--h4-font-weight); + font-size: var(--h4-font-size); + font-style: normal; + + line-height: var(--h4-line-height); + margin-top: 1.25em; + margin-bottom: 0.5em !important; /* override RTD theme */ + + text-transform: var(--h4-text-transform); +} + +.bd-container h5 +{ + color: var(--h5-color); + + font-size: var(--h5-font-size); +} + +.bd-container h6 +{ + color: var(--h6-color); + + font-weight: var(--h6-font-weight); +} + +/* Math should inherit its color */ +span[id*=MathJax-Span] +{ + color: inherit; +} + +/* text highlighted by search */ +.rst-content .highlighted +{ + background: #f1c40f3b; + box-shadow: 0 0 0 1px #f1c40f; + display: inline; + font-weight: inherit; +} + +/* a local table-of-contents messes with heading colors. make sure to use the regular heading colors */ +.rst-content .toc-backref +{ + color: inherit; +} + +/* make links to function looks like other literals */ +.rst-content code.xref, +.rst-content tt.xref, +a .rst-content code, +a .rst-content tt +{ + color: #e74c3c; + font-weight: inherit; +} + +/* Link Colors */ +a +{ + color: var(--link-color); +} + +a:visited +{ + color: var(--link-visited-color); +} + +a:hover +{ + color: var(--link-hover-color); +} + +/* follow branding guide for small footer text */ +footer p +{ + color: var(--small-font-color); + font-weight: var(--small-font-weight); + font-size: var(--small-font-size); + line-height: var(--small-line-height); +} + +/* add nvidia logo (like www.nvidia.com) */ +html[data-theme="light"] footer.bd-footer-content p.copyright::before +{ + content: url(../NVIDIA-LogoBlack.svg); + display: block; + width: 110px; + margin: 0px; + position: relative; + left: -9px; +} + +/* add nvidia logo (like www.nvidia.com) */ +html[data-theme="dark"] footer.bd-footer-content p.copyright::before +{ + content: url(../NVIDIA-LogoWhite.svg); display: block; - width: 300px; - padding: .809em; - padding-top: 0.809em; - margin-bottom: .809em; - z-index: 200; - background-color: #2980b9; - text-align: center; - color: #fcfcfc; - padding-top: 40px; -} - -div.banner { - position: fixed; - top: 10px; - left: 20px; - margin: 0; - z-index: 1000; - width: 1050px; - text-align: center; -} - -p.banner { - border-radius: 4px; - color: #004831; - background: #76b900; -} \ No newline at end of file + width: 110px; + margin: 0px; + position: relative; + left: -9px; +} + + +/**********************************************************************************************************************/ +/* Lists */ +/**********************************************************************************************************************/ + +/* unordered list should have a nv-green > */ +.rst-content section ul:not(.treeView):not(.collapsibleList) li:not(.collapsibleListClosed):not(.collapsibleListOpen):not(.lastChild)::marker, +.rst-content .toctree-wrapper ul li::marker, +.wy-plain-list-disc li::marker, +article ul li::marker +{ + font-family: var(--ul-marker-font-face); + content: var(--ul-marker-content); + color: var(--nv-green); + font-weight: 600; +} + +/* top-level ordered list should have a nv-green number */ +.rst-content section ol li::marker, +.rst-content ol.arabic li::marker, +.wy-plain-list-decimal li::marker, +article ol li::marker +{ + color: var(--nv-green); + font-weight: 600; + list-style: var(--ol-list-style-type); +} + +/* second-level ordered list should have a nv-green uppercase letter */ 
+.rst-content section ol ol li, +.rst-content ol.arabic ol.arabic li, +.wy-plain-list-decimal ol ol li, +article ol ol li +{ + list-style: var(--ol-ol-list-style-type); +} + +/* third-level ordered lists aren't in the branding guide. let's use numbers. */ +.rst-content section ol ol ol li, +.rst-content ol.arabic ol.arabic ol li, +.wy-plain-list-decimal ol ol ol li, +article ol ol ol li +{ + list-style: var(--ol-ol-ol-list-style-type); +} + +/* start the first paragraph immediately (don't add space at the top) */ +dd p:first-child +{ + margin-top: 0px; +} diff --git a/docs/source/_static/css/versions.css b/docs/source/_static/css/versions.css new file mode 100644 index 0000000000..cafebc54ba --- /dev/null +++ b/docs/source/_static/css/versions.css @@ -0,0 +1,140 @@ +/* Version Switcher */ + +.rst-versions { + position: fixed; + bottom: 0; + left: 0; + z-index: 400 +} + +.rst-versions a { + color: var(--nv-green); + text-decoration: none +} + +.rst-versions .rst-badge-small { + display: none +} + +.rst-versions .rst-current-version { + padding: 12px; + display: block; + text-align: right; + font-size: 90%; + cursor: pointer; + border-top: 1px solid rgba(0,0,0,.1); + *zoom:1 +} + +.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after { + display: table; + content: "" +} + +.rst-versions .rst-current-version:after { + clear: both +} + +.rst-versions .rst-current-version .fa-book { + float: left +} + +.rst-versions .rst-current-version .icon-book { + float: left +} + +.rst-versions .rst-current-version.rst-out-of-date { + background-color: #E74C3C; + color: #fff +} + +.rst-versions .rst-current-version.rst-active-old-version { + background-color: #F1C40F; + color: #000 +} + +.rst-versions.shift-up { + height: auto; + max-height: 100% +} + +.rst-versions.shift-up .rst-other-versions { + display: block +} + +.rst-versions .rst-other-versions { + font-size: 90%; + padding: 12px; + color: gray; + display: none +} + +.rst-versions .rst-other-versions hr { + display: block; + height: 1px; + border: 0; + margin: 20px 0; + padding: 0; + border-top: solid 1px #413d3d +} + +.rst-versions .rst-other-versions dd { + display: inline-block; + margin: 0 +} + +.rst-versions .rst-other-versions dd a { + display: inline-block; + padding: 6px; + color: var(--nv-green); + font-weight: 500; +} + +.rst-versions.rst-badge { + width: auto; + bottom: 20px; + right: 20px; + left: auto; + border: none; + max-width: 300px +} + +.rst-versions.rst-badge .icon-book { + float: none +} + +.rst-versions.rst-badge .fa-book { + float: none +} + +.rst-versions.rst-badge.shift-up .rst-current-version { + text-align: right +} + +.rst-versions.rst-badge.shift-up .rst-current-version .fa-book { + float: left +} + +.rst-versions.rst-badge.shift-up .rst-current-version .icon-book { + float: left +} + +.rst-versions.rst-badge .rst-current-version { + width: auto; + height: 30px; + line-height: 30px; + padding: 0 6px; + display: block; + text-align: center +} + +@media screen and (max-width: 768px) { + .rst-versions { + width:85%; + display: none + } + + .rst-versions.shift { + display: block + } +} diff --git a/docs/source/_static/favicon.png b/docs/source/_static/favicon.png new file mode 100755 index 0000000000..a00f862ecf Binary files /dev/null and b/docs/source/_static/favicon.png differ diff --git a/docs/source/_static/js/rtd-version-switcher.js b/docs/source/_static/js/rtd-version-switcher.js new file mode 100644 index 0000000000..6b4ccd0f1f --- /dev/null +++ 
b/docs/source/_static/js/rtd-version-switcher.js @@ -0,0 +1,5 @@ +var jQuery = (typeof(window) != 'undefined') ? window.jQuery : require('jquery'); +var doc = $(document); +doc.on('click', "[data-toggle='rst-current-version']", function() { + $("[data-toggle='rst-versions']").toggleClass("shift-up"); +}); diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 76917f64c1..b053113ba3 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -1,9 +1,21 @@ -{% extends "!layout.html" %} -{% block extrabody %} - -{% endblock %} +{%- extends "!layout.html" %} + +{%- block extrahead %} + {%- if analytics_id %} + + + + {% endif %} + + + + +{%- endblock %} diff --git a/docs/source/_templates/merlin-ecosystem.html b/docs/source/_templates/merlin-ecosystem.html new file mode 100644 index 0000000000..c925bb1442 --- /dev/null +++ b/docs/source/_templates/merlin-ecosystem.html @@ -0,0 +1,14 @@ + diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html index 31a1257898..26e2a32fcd 100644 --- a/docs/source/_templates/versions.html +++ b/docs/source/_templates/versions.html @@ -1,7 +1,7 @@ {%- if current_version %}
- Other Versions + v: {{ current_version.name }} diff --git a/docs/source/conf.py b/docs/source/conf.py index a062665709..e110990e32 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -42,7 +42,7 @@ # -- Project information ----------------------------------------------------- project = "Merlin Models" -copyright = "2022, NVIDIA" +copyright = "2023, NVIDIA" author = "NVIDIA" @@ -54,7 +54,6 @@ extensions = [ "myst_nb", "sphinx_multiversion", - "sphinx_rtd_theme", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.coverage", @@ -96,17 +95,34 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "sphinx_rtd_theme" +html_theme = "sphinx_book_theme" +html_title = "NVIDIA Merlin Models" +html_favicon = "_static/favicon.png" html_theme_options = { - "navigation_depth": 2, - "analytics_id": "G-NVJ1Y1YJHK", + "repository_url": "https://github.com/NVIDIA-Merlin/models", + "use_repository_button": True, + "footer_content_items": ["copyright.html", "last-updated.html"], + "extra_footer": "", + "logo": {"text": "NVIDIA Merlin Models", "alt_text": "NVIDIA Merlin Models"}, } +html_sidebars = { + "**": [ + "navbar-logo.html", + "search-field.html", + "icon-links.html", + "sbt-sidebar-nav.html", + "merlin-ecosystem.html", + "versions.html", + ] +} +html_context = {"analytics_id": "G-NVJ1Y1YJHK"} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] -html_css_files = ["css/custom.css"] +html_css_files = ["css/custom.css", "css/versions.css"] +html_js_files = ["js/rtd-version-switcher.js"] source_suffix = [".rst", ".md"] @@ -126,7 +142,6 @@ smv_refs_override_suffix = r"-docs" -html_sidebars = {"**": ["versions.html"]} html_baseurl = "https://nvidia-merlin.github.io/models/stable/" intersphinx_mapping = { diff --git a/examples/05-Retrieval-Model.ipynb b/examples/05-Retrieval-Model.ipynb index c2553347e5..5306610e01 100644 --- a/examples/05-Retrieval-Model.ipynb +++ b/examples/05-Retrieval-Model.ipynb @@ -1616,7 +1616,8 @@ } ], "source": [ - "queries = model.query_embeddings(Dataset(user_features, schema=schema), batch_size=1024, index=Tags.USER_ID)\n", + "queries = model.query_embeddings(Dataset(user_features, schema=schema.select_by_tag(Tags.USER)), \n", + " batch_size=1024, index=Tags.USER_ID)\n", "query_embs_df = queries.compute(scheduler=\"synchronous\").reset_index()" ] }, @@ -1996,7 +1997,8 @@ } ], "source": [ - "item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema), batch_size=1024, index=Tags.ITEM_ID)" + "item_embs = model.candidate_embeddings(Dataset(item_features, schema=schema.select_by_tag(Tags.ITEM)), \n", + " batch_size=1024, index=Tags.ITEM_ID)" ] }, { @@ -2460,7 +2462,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.8.10" }, "merlin": { "containers": [ diff --git a/examples/usecases/ecommerce-session-based-next-item-prediction-for-fashion.ipynb b/examples/08-Train-a-model-for-session-based-next-item-prediction.ipynb similarity index 99% rename from examples/usecases/ecommerce-session-based-next-item-prediction-for-fashion.ipynb rename to examples/08-Train-a-model-for-session-based-next-item-prediction.ipynb index 6769dae3f2..e6c7ea1a04 100644 --- 
a/examples/usecases/ecommerce-session-based-next-item-prediction-for-fashion.ipynb +++ b/examples/08-Train-a-model-for-session-based-next-item-prediction.ipynb @@ -41,6 +41,8 @@ "\n", "NVIDIA-Merlin team participated in [Recsys2022 challenge](http://www.recsyschallenge.com/2022/index.html) and secured 3rd position. This notebook contains the various techniques used in the solution.\n", "\n", + "In this notebook, we train several different architectures, the last one being a transformer model. We only cover training. If you are also interested in putting your model into production and serving predictions with the industry-standard Triton Inference Server, please consult [this notebook](https://github.com/NVIDIA-Merlin/Merlin/blob/main/examples/Next-Item-Prediction-with-Transformers/tf/transformers-next-item-prediction.ipynb).\n", + "\n", "### Learning Objective\n", "\n", "In this notebook, we will apply important concepts that improve recommender systems. We leveraged them for our RecSys solution:\n", @@ -860,7 +862,7 @@ "\n", "We train a Sequential-Multi-Layer Perceptron model, which averages the sequential input features (e.g. `item_id_list_seq`) and concatenate the resulting embeddings with the categorical embeddings (e.g. `item_id_last`). We visualize the architecture in the figure below.\n", "\n", - "" + "" ] }, { @@ -1285,7 +1287,7 @@ "source": [ "In this section, we train a Bi-LSTM model, an extension of traditional LSTMs, which enables straight (past) and reverse traversal of input (future) sequence to be used. The input block concatenates the embedding vectors for all sequential features (`item_id_list_seq`, `f_47_list_seq`, `f_68_list_seq`) per step (e.g. here 3). The concatenated vectors are processed by a BiLSTM architecture. The hidden state of the BiLSTM is concatenated with the embedding vectors of the categorical features (`item_id_last`). Then we connect it with a Multi-Layer Perceptron Block. We visualize the architecture in the figure below.\n", - "" + "" ] }, { @@ -2093,7 +2095,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.10 ('merlin_22.07_dev')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/examples/usecases/transformers-next-item-prediction-with-pretrained-embeddings.ipynb b/examples/usecases/transformers-next-item-prediction-with-pretrained-embeddings.ipynb deleted file mode 100644 index 1632331ee5..0000000000 --- a/examples/usecases/transformers-next-item-prediction-with-pretrained-embeddings.ipynb +++ /dev/null @@ -1,1433 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "5b545747", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright 2022 NVIDIA Corporation. 
All Rights Reserved.\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#`\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions anda\n", - "# limitations under the License.\n", - "# ==============================================================================\n", - "\n", - "# Each user is responsible for checking the content of datasets and the\n", - "# applicable licenses and determining if suitable for the intended use." - ] - }, - { - "cell_type": "markdown", - "id": "5ec6d3b8", - "metadata": {}, - "source": [ - "\n", - "\n", - "# Transformer-based architecture for next-item prediction task with pretrained embeddings\n", - "\n", - "This notebook is created using the latest stable [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container.\n", - "\n", - "## Overview\n", - "\n", - "In this use case we will train a Transformer-based architecture for next-item prediction task with pretrained embeddings.\n", - "\n", - "**You can chose to download the full dataset manually or use synthetic data.**\n", - "\n", - "We will use the [SIGIR eCOM 2021 Data Challenge Dataset](https://github.com/coveooss/SIGIR-ecom-data-challenge) to train a session-based model. The dataset contains 36M events of users browsing an online store.\n", - "\n", - "We will reshape the data to organize it into 'sessions'. Each session will be a full customer online journey in chronological order. 
The goal will be to predict the `url` of the next action taken.\n", - "\n", - "\n", - "### Learning objectives\n", - "\n", - "- Training a Transformer-based architecture for next-item prediction task" - ] - }, - { - "cell_type": "markdown", - "id": "fd2b847f", - "metadata": {}, - "source": [ - "## Downloading and preparing the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "2dd7827c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-06-20 22:58:36.667322: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\n", - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/torch.py:43: UserWarning: PyTorch dtype mappings did not load successfully due to an error: No module named 'torch'\n", - " warn(f\"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}\")\n", - "2023-06-20 22:58:38.026020: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:38.026445: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:38.026622: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n" - ] - } - ], - "source": [ - "import os\n", - "import cudf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import nvtabular as nvt\n", - "from merlin.schema import ColumnSchema, Schema, Tags\n", - "\n", - "OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/workspace/data')\n", - "NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', 5))\n", - "NUM_EXAMPLES = int(os.environ.get('NUM_EXAMPLES', 100_000))\n", - "MINIMUM_SESSION_LENGTH = int(os.environ.get('MINIMUM_SESSION_LENGTH', 5))" - ] - }, - { - "cell_type": "markdown", - "id": "7fcf7c86", - "metadata": {}, - "source": [ - "You can download the full dataset by registering [here](https://www.coveo.com/en/ailabs/sigir-ecom-data-challenge). If you chose to download the data, please place it alongside this notebook in the `sigir_dataset` directory and extract it.\n", - "\n", - "By default, in this notebook we will be using synthetically generated data based on the SIGIR dataset, but you can run on the full dataset by changing the value of the boolean flag below." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "bc3d1882", - "metadata": {}, - "outputs": [], - "source": [ - "RUN_ON_SYNTHETIC_DATA = True" - ] - }, - { - "cell_type": "markdown", - "id": "68bc6d6d", - "metadata": {}, - "source": [ - "### Clean downloaded data" - ] - }, - { - "cell_type": "markdown", - "id": "9016a3e2", - "metadata": {}, - "source": [ - "If you are training on the full SIGIR dataset, the following code will pre-process it.\n", - "\n", - "Here we deal with `nan` values, drop rows with missing information and parse strings containing lists to lists.\n", - "\n", - "The synthetically generated data is already clean -- it doesn't require this pre-processing." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "428ab049", - "metadata": {}, - "outputs": [], - "source": [ - "if not RUN_ON_SYNTHETIC_DATA:\n", - " train = nvt.Dataset('/workspace/sigir_dataset/train/browsing_train.csv', part_size='500MB')\n", - " skus = nvt.Dataset('/workspace/sigir_dataset/train/sku_to_content.csv')\n", - "\n", - " skus = pd.read_csv('/workspace/sigir_dataset/train/sku_to_content.csv')\n", - "\n", - " skus['description_vector'] = skus['description_vector'].replace(np.nan, '')\n", - " skus['image_vector'] = skus['image_vector'].replace(np.nan, '')\n", - "\n", - " skus['description_vector'] = skus['description_vector'].apply(lambda x: [] if len(x) == 0 else eval(x))\n", - " skus['image_vector'] = skus['image_vector'].apply(lambda x: [] if len(x) == 0 else eval(x))\n", - " skus = skus[skus.description_vector.apply(len) > 0]\n", - " skus = nvt.Dataset(skus)" - ] - }, - { - "cell_type": "markdown", - "id": "9b33fa32", - "metadata": {}, - "source": [ - "### Generate synthetic data" - ] - }, - { - "cell_type": "markdown", - "id": "4c4ba9b9", - "metadata": {}, - "source": [ - "If you are not running on the full dataset, the following lines of code will generate its synthetic counterpart." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "84789211", - "metadata": {}, - "outputs": [], - "source": [ - "if RUN_ON_SYNTHETIC_DATA:\n", - " from merlin.datasets.synthetic import generate_data\n", - "\n", - " train = generate_data('sigir-browsing', NUM_EXAMPLES)\n", - " skus = generate_data('sigir-sku', NUM_EXAMPLES)" - ] - }, - { - "cell_type": "markdown", - "id": "5533f446", - "metadata": {}, - "source": [ - "## Constructing a workflow" - ] - }, - { - "cell_type": "markdown", - "id": "ac47bc4e", - "metadata": {}, - "source": [ - "We need to process our data further before we can use it to train our model.\n", - "\n", - "In particular, the `skus` dataset contains the mapping between the `product_sku_hash` (essentially an item id) to the `description_vector` -- an embedding obtained from the description.\n", - "\n", - "We would like to enable our model to make use of this piece of information. In order to feed this data to our model, we need to map the `product_sku_hash` to an id.\n", - "\n", - "But we need to make sure that the way we process `skus` and the `train` dataset (event information) is consistent, that the same `product_sku_hash` is mapped to the same id both when processing `skus` and `train`.\n", - "\n", - "We do so by defining and fitting a `Categorify` op once and using it to process both the `skus` and the `train` datasets.\n", - "\n", - "Additionally, we apply some further processing to the `train` dataset. 
We group rows by `session_id_hash` so that each training example will contain events from a single customer visit to the online store arranged in chronological order.\n", - "\n", - "If you would like to learn more about leveraging `NVTabular` to process tabular data on the GPU using a set of industry standard operators, please consult the examples available [here](https://github.com/NVIDIA-Merlin/NVTabular/tree/main/examples).\n", - "\n", - "Let's first process the `train` dataset and retain the `Categorify` operator (`cat_op`) for processing of `skus`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "3b5feee3", - "metadata": {}, - "outputs": [], - "source": [ - "cat_op = nvt.ops.Categorify()\n", - "out = ['product_sku_hash'] >> cat_op >> nvt.ops.TagAsItemID()\n", - "out += ['event_type', 'product_action', 'session_id_hash', 'hashed_url'] >> nvt.ops.Categorify()\n", - "out += ['server_timestamp_epoch_ms'] >> nvt.ops.NormalizeMinMax()\n", - "\n", - "groupby_features = out >> nvt.ops.Groupby(\n", - " groupby_cols=['session_id_hash'],\n", - " aggs={\n", - " 'product_sku_hash': ['list'],\n", - " 'event_type': ['list'],\n", - " 'product_action': ['list'],\n", - " 'hashed_url': ['list', 'count'],\n", - " 'server_timestamp_epoch_ms': ['list']\n", - " },\n", - " sort_cols=\"server_timestamp_epoch_ms\"\n", - ")\n", - "\n", - "filtered_sessions = groupby_features >> nvt.ops.Filter(f=lambda df: df[\"hashed_url_count\"] >= MINIMUM_SESSION_LENGTH)\n", - "\n", - "# We won't be needing the `session_id_hash` nor the `hashed_url_count` any longer\n", - "wf = nvt.Workflow(\n", - " filtered_sessions[\n", - " 'product_sku_hash_list',\n", - " 'event_type_list',\n", - " 'product_action_list',\n", - " 'hashed_url_list',\n", - " ]\n", - ")\n", - "\n", - "# Let's save the output of our workflow -- transformed `train` for later use (training of our model).\n", - "wf.fit_transform(train).to_parquet('train_transformed')" - ] - }, - { - "cell_type": "markdown", - "id": "45a4828e", - "metadata": {}, - "source": [ - "Here are a couple of example rows from `train_transformed`." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "650fb0d0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_sku_hash_listevent_type_listproduct_action_listhashed_url_list
0[578, 972, 378, 420, 328, 126, 233, 925, 410, ...[3, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 3, 4, 4, 4, ...[3, 3, 5, 6, 4, 3, 3, 4, 4, 4, 6, 5, 3, 4, 3, ...[766, 955, 745, 210, 940, 688, 986, 524, 425, ...
1[298, 304, 393, 697, 706, 313, 834, 83, 502, 1...[4, 4, 4, 3, 4, 4, 4, 3, 3, 3, 4, 4, 3, 4, 3, ...[3, 5, 6, 4, 4, 3, 3, 3, 6, 6, 3, 3, 6, 6, 3, ...[13, 221, 915, 658, 456, 378, 802, 180, 580, 4...
2[706, 221, 22, 702, 339, 645, 436, 358, 84, 35...[4, 3, 4, 4, 4, 4, 4, 4, 3, 3, 3, 4, 3, 4, 3, ...[3, 6, 4, 6, 3, 3, 5, 5, 4, 6, 4, 6, 3, 5, 6, ...[271, 940, 562, 498, 172, 239, 270, 215, 489, ...
3[278, 153, 189, 717, 580, 540, 219, 79, 200, 9...[3, 3, 3, 3, 4, 4, 3, 4, 4, 3, 4, 4, 3, 3, 3, ...[6, 6, 6, 6, 3, 4, 4, 4, 4, 4, 3, 6, 5, 4, 3, ...[169, 419, 875, 725, 926, 770, 160, 554, 763, ...
4[156, 922, 914, 592, 842, 916, 137, 928, 615, ...[3, 4, 4, 4, 3, 4, 4, 4, 4, 3, 4, 3, 4, 3, 4, ...[6, 4, 5, 6, 5, 4, 3, 3, 6, 5, 6, 5, 3, 6, 3, ...[318, 506, 281, 191, 506, 480, 965, 399, 761, ...
\n", - "
" - ], - "text/plain": [ - " product_sku_hash_list \\\n", - "0 [578, 972, 378, 420, 328, 126, 233, 925, 410, ... \n", - "1 [298, 304, 393, 697, 706, 313, 834, 83, 502, 1... \n", - "2 [706, 221, 22, 702, 339, 645, 436, 358, 84, 35... \n", - "3 [278, 153, 189, 717, 580, 540, 219, 79, 200, 9... \n", - "4 [156, 922, 914, 592, 842, 916, 137, 928, 615, ... \n", - "\n", - " event_type_list \\\n", - "0 [3, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 3, 4, 4, 4, ... \n", - "1 [4, 4, 4, 3, 4, 4, 4, 3, 3, 3, 4, 4, 3, 4, 3, ... \n", - "2 [4, 3, 4, 4, 4, 4, 4, 4, 3, 3, 3, 4, 3, 4, 3, ... \n", - "3 [3, 3, 3, 3, 4, 4, 3, 4, 4, 3, 4, 4, 3, 3, 3, ... \n", - "4 [3, 4, 4, 4, 3, 4, 4, 4, 4, 3, 4, 3, 4, 3, 4, ... \n", - "\n", - " product_action_list \\\n", - "0 [3, 3, 5, 6, 4, 3, 3, 4, 4, 4, 6, 5, 3, 4, 3, ... \n", - "1 [3, 5, 6, 4, 4, 3, 3, 3, 6, 6, 3, 3, 6, 6, 3, ... \n", - "2 [3, 6, 4, 6, 3, 3, 5, 5, 4, 6, 4, 6, 3, 5, 6, ... \n", - "3 [6, 6, 6, 6, 3, 4, 4, 4, 4, 4, 3, 6, 5, 4, 3, ... \n", - "4 [6, 4, 5, 6, 5, 4, 3, 3, 6, 5, 6, 5, 3, 6, 3, ... \n", - "\n", - " hashed_url_list \n", - "0 [766, 955, 745, 210, 940, 688, 986, 524, 425, ... \n", - "1 [13, 221, 915, 658, 456, 378, 802, 180, 580, 4... \n", - "2 [271, 940, 562, 498, 172, 239, 270, 215, 489, ... \n", - "3 [169, 419, 875, 725, 926, 770, 160, 554, 763, ... \n", - "4 [318, 506, 281, 191, 506, 480, 965, 399, 761, ... " - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "nvt.Dataset('train_transformed', engine='parquet').head()" - ] - }, - { - "cell_type": "markdown", - "id": "18f12dbd", - "metadata": {}, - "source": [ - "Now that we have processed the train set, we can use the mapping preserved in the `cat_op` to process the `skus` dataset containing the embeddings we are after.\n", - "\n", - "Let's now `Categorify` the `product_sku_hash` in `skus` and grab just the description embedding information." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "313808d0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_sku_hashdescription_vectorcategory_hashprice_bucket
013[0.07939800762120258, 0.3465797761609977, -0.3...160.186690
125[0.4275482879608162, -0.30569476366666, 0.1440...380.951997
218[-0.31035419787213536, 0.18070481533058008, 0....220.973384
31[-0.31319783485940356, -0.11623980504981396, -...1380.146260
411[0.25091279302969943, -0.33473442518442525, 0....1190.808252
\n", - "
" - ], - "text/plain": [ - " product_sku_hash description_vector \\\n", - "0 13 [0.07939800762120258, 0.3465797761609977, -0.3... \n", - "1 25 [0.4275482879608162, -0.30569476366666, 0.1440... \n", - "2 18 [-0.31035419787213536, 0.18070481533058008, 0.... \n", - "3 1 [-0.31319783485940356, -0.11623980504981396, -... \n", - "4 11 [0.25091279302969943, -0.33473442518442525, 0.... \n", - "\n", - " category_hash price_bucket \n", - "0 16 0.186690 \n", - "1 38 0.951997 \n", - "2 22 0.973384 \n", - "3 138 0.146260 \n", - "4 119 0.808252 " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "skus.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "dfad1bcf", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
product_sku_hashdescription_vector
0836[0.07939800762120258, 0.3465797761609977, -0.3...
1979[0.4275482879608162, -0.30569476366666, 0.1440...
211[-0.31035419787213536, 0.18070481533058008, 0....
3469[-0.31319783485940356, -0.11623980504981396, -...
4118[0.25091279302969943, -0.33473442518442525, 0....
\n", - "
" - ], - "text/plain": [ - " product_sku_hash description_vector\n", - "0 836 [0.07939800762120258, 0.3465797761609977, -0.3...\n", - "1 979 [0.4275482879608162, -0.30569476366666, 0.1440...\n", - "2 11 [-0.31035419787213536, 0.18070481533058008, 0....\n", - "3 469 [-0.31319783485940356, -0.11623980504981396, -...\n", - "4 118 [0.25091279302969943, -0.33473442518442525, 0...." - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "out = ['product_sku_hash'] >> cat_op\n", - "wf_skus = nvt.Workflow(out + 'description_vector')\n", - "skus_ds = wf_skus.transform(skus)\n", - "\n", - "skus_ds.head()" - ] - }, - { - "cell_type": "markdown", - "id": "360fe65d", - "metadata": {}, - "source": [ - "Let us now export the embedding information to a `numpy` array and write it to disk.\n", - "\n", - "We will later pass this information to the `Loader` so that it will load the correct emebedding for the product corresponding to a given step of a customer journey.\n", - "\n", - "The embeddings are linked to the train set using the `product_sku_hash` information." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "d99dfdd0", - "metadata": {}, - "outputs": [], - "source": [ - "skus_ds.to_npy('skus.npy')" - ] - }, - { - "cell_type": "markdown", - "id": "58d80879", - "metadata": {}, - "source": [ - "How will the `Loader` know which embedding to associate with a given row of the train set?\n", - "\n", - "The `product_sku_hash` ids have been exported along with the embeddings and are contained in the first column of the output `numpy` array.\n", - "\n", - "Here is the id of the first embedding stored in `skus.npy`:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "d60c6651", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "836.0" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.load('skus.npy')[0, 0]" - ] - }, - { - "cell_type": "markdown", - "id": "974cf669", - "metadata": {}, - "source": [ - "and here is the embedding vector corresponding to `product_sku_hash` of id referenced above:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "c2c111fd", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 0.07939801, 0.34657978, -0.38269496, 0.56307004, -0.10142923,\n", - " 0.03702352, -0.11606304, 0.10070879, -0.21879928, 0.06107687,\n", - " -0.20743195, -0.01330719, 0.60182867, 0.0920322 , 0.2648726 ,\n", - " 0.56061561, 0.48643498, 0.39045152, -0.40012162, 0.09153962,\n", - " -0.38351605, 0.57134731, 0.59986226, -0.40321368, -0.32984972,\n", - " 0.37559494, 0.1554353 , -0.0413067 , 0.33814398, 0.30678041,\n", - " 0.24001132, 0.42737922, 0.41554601, -0.40451691, 0.50428902,\n", - " -0.2004803 , -0.38297056, 0.06580838, 0.48285745, 0.51406472,\n", - " 0.02268894, 0.36343324, 0.32497967, -0.29736346, -0.00538915,\n", - " 0.12329302, -0.04998194, 0.27843002, 0.20212714, 0.39019503])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.load('skus.npy')[0, 1:]" - ] - }, - { - "cell_type": "markdown", - "id": "7b8c4a13", - "metadata": {}, - "source": [ - "We are now ready to construct the `Loader` that will feed the data to our model.\n", - "\n", - "We begin by reading in the embeddings information." 
- ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "51e1f766", - "metadata": {}, - "outputs": [], - "source": [ - "embeddings = np.load('skus.npy')" - ] - }, - { - "cell_type": "markdown", - "id": "e0b1f18d", - "metadata": {}, - "source": [ - "We are now ready to define the `Loader`.\n", - "\n", - "We are passing in an `EmbeddingOperator` that will ensure that correct `sku` information (correct `description_vector`) is associated with the correct step in the customer journey (with the lookup key being contained in the `product_sku_hash_list`)\n", - "\n", - "When specifying the dataset, we are creating a `Merlin Dataset` based on the `train_transformed` data we saved above.\n", - "\n", - "Depending on the hardware that you will be running this on and the size of the dataset that you will be using, should you run out of GPU memory, you can specify one of the several parameters that can ease the memory load (`npartitions`, `part_size`, or `part_mem_fraction`).\n", - "\n", - "The `BATCH_SIZE` of 16 should work on a broad set of hardware, but if you are training on a lot of data and your hardware permitting you might want to significantly increase it." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "1d7212fc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.\n", - "[INFO]: sparse_operation_kit is imported\n", - "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. The old module will be deleted in version 2.11.\n", - "[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so\n", - "[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so\n", - "[SOK INFO] Initialize finished, communication tool: horovod\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-06-20 22:58:50.835162: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\n", - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-06-20 22:58:50.836068: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:50.836268: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:50.836425: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:50.836673: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one 
NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:50.836849: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:50.837009: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-06-20 22:58:50.837114: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.\n", - "2023-06-20 22:58:50.837130: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1621] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 24576 MB memory: -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:08:00.0, compute capability: 7.5\n" - ] - } - ], - "source": [ - "BATCH_SIZE = 16\n", - "\n", - "from merlin.dataloader.tensorflow import Loader\n", - "from merlin.dataloader.ops.embeddings import EmbeddingOperator\n", - "import merlin.models.tf as mm\n", - "\n", - "embedding_operator = EmbeddingOperator(\n", - " embeddings[:, 1:].astype(np.float32),\n", - " id_lookup_table=embeddings[:, 0].astype(int),\n", - " lookup_key=\"product_sku_hash_list\",\n", - " embedding_name='product_embeddings'\n", - ")\n", - "\n", - "loader = Loader(\n", - " dataset=nvt.Dataset('train_transformed', engine='parquet'),\n", - " batch_size=BATCH_SIZE,\n", - " transforms=[\n", - " embedding_operator\n", - " ],\n", - " shuffle=True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "4f037d5d", - "metadata": {}, - "source": [ - "Using the `EmbeddingOperator` object we referenced our `product_embeddings` and insructed the model what to use as a key to look up the information.\n", - "\n", - "Below is an example batch of data that our model will consume." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "7371e23d", - "metadata": {}, - "outputs": [], - "source": [ - "batch = mm.sample_batch(loader, batch_size=BATCH_SIZE, include_targets=False, prepare_features=True)" - ] - }, - { - "cell_type": "markdown", - "id": "f7c9a50d", - "metadata": {}, - "source": [ - "`product_embeddings` are included in the batch." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "3cbf8ea4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['product_sku_hash_list', 'event_type_list', 'product_action_list', 'hashed_url_list', 'product_embeddings'])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "batch.keys()" - ] - }, - { - "cell_type": "markdown", - "id": "53e61e71", - "metadata": {}, - "source": [ - "## Creating and training the model" - ] - }, - { - "cell_type": "markdown", - "id": "2461926e", - "metadata": {}, - "source": [ - "We are now ready to construct our model." 
- ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "6867c8ba", - "metadata": {}, - "outputs": [], - "source": [ - "import merlin.models.tf as mm\n", - "\n", - "input_block = mm.InputBlockV2(\n", - " loader.output_schema,\n", - " embeddings=mm.Embeddings(\n", - " loader.output_schema.select_by_tag(Tags.CATEGORICAL),\n", - " sequence_combiner=None,\n", - " ),\n", - " pretrained_embeddings=mm.PretrainedEmbeddings(\n", - " loader.output_schema.select_by_tag(Tags.EMBEDDING),\n", - " sequence_combiner=None,\n", - " normalizer=\"l2-norm\",\n", - " output_dims={\"product_embeddings\": 64},\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "cafb788f", - "metadata": {}, - "source": [ - "We have now constructed an `input_block` that will take our batch and transform it in a fashion that will make it amenable for further processing by subsequent layers of our model.\n", - "\n", - "To test that everything has worked, we can pass our example `batch` through the `input_block`." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "3f8afa56", - "metadata": {}, - "outputs": [], - "source": [ - "input_batch = input_block(batch)" - ] - }, - { - "cell_type": "markdown", - "id": "d24a70fe", - "metadata": {}, - "source": [ - "Let us now construct the remaining layers of our model." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "78b21c0f", - "metadata": {}, - "outputs": [], - "source": [ - "target = 'hashed_url_list'\n", - "\n", - "# We do not need the `train_transformed` dataset here, but we do need\n", - "# to access the schema.\n", - "# It contains important information that will help our model construct itself.\n", - "schema = nvt.Dataset('train_transformed', engine='parquet').schema\n", - "\n", - "dmodel=64\n", - "mlp_block = mm.MLPBlock(\n", - " [128,dmodel],\n", - " activation='relu',\n", - " no_activation_last_layer=True,\n", - " )\n", - "transformer_block = mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2)\n", - "model = mm.Model(\n", - " input_block,\n", - " mlp_block,\n", - " transformer_block,\n", - " mm.CategoricalOutput(\n", - " schema.select_by_name(target),\n", - " default_loss=\"categorical_crossentropy\",\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "13b54d19", - "metadata": {}, - "source": [ - "And let us train it." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "fbb03f0c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n", - " warnings.warn(\n", - "2023-06-20 22:58:58.950175: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8700\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/5\n", - "WARNING:tensorflow:Gradients do not exist for variables ['model/mask_emb:0', 'transformer/layer_._0/rel_attn/r_s_bias:0', 'transformer/layer_._0/rel_attn/seg_embed:0', 'transformer/layer_._1/rel_attn/r_s_bias:0', 'transformer/layer_._1/rel_attn/seg_embed:0'] when minimizing the loss. 
If you're using `model.compile()`, did you forget to provide a `loss` argument?\n", - "WARNING:tensorflow:Gradients do not exist for variables ['model/mask_emb:0', 'transformer/layer_._0/rel_attn/r_s_bias:0', 'transformer/layer_._0/rel_attn/seg_embed:0', 'transformer/layer_._1/rel_attn/r_s_bias:0', 'transformer/layer_._1/rel_attn/seg_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-06-20 22:59:11.285571: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: model/xl_net_block/sequential_block_7/replace_masked_embeddings/RaggedWhere/Assert/AssertGuard/branch_executed/_95\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "18/18 [==============================] - 42s 2s/step - loss: 6.9800 - recall_at_10: 0.0106 - mrr_at_10: 0.0033 - ndcg_at_10: 0.0050 - map_at_10: 0.0033 - precision_at_10: 0.0011 - regularization_loss: 0.0000e+00 - loss_batch: 6.9689\n", - "Epoch 2/5\n", - "18/18 [==============================] - 34s 2s/step - loss: 6.9591 - recall_at_10: 0.0106 - mrr_at_10: 0.0031 - ndcg_at_10: 0.0048 - map_at_10: 0.0031 - precision_at_10: 0.0011 - regularization_loss: 0.0000e+00 - loss_batch: 6.9363\n", - "Epoch 3/5\n", - "18/18 [==============================] - 39s 2s/step - loss: 6.9471 - recall_at_10: 0.0107 - mrr_at_10: 0.0028 - ndcg_at_10: 0.0046 - map_at_10: 0.0028 - precision_at_10: 0.0011 - regularization_loss: 0.0000e+00 - loss_batch: 6.9206\n", - "Epoch 4/5\n", - "18/18 [==============================] - 38s 2s/step - loss: 6.9398 - recall_at_10: 0.0103 - mrr_at_10: 0.0030 - ndcg_at_10: 0.0047 - map_at_10: 0.0030 - precision_at_10: 0.0010 - regularization_loss: 0.0000e+00 - loss_batch: 6.9015\n", - "Epoch 5/5\n", - "18/18 [==============================] - 38s 2s/step - loss: 6.9375 - recall_at_10: 0.0104 - mrr_at_10: 0.0030 - ndcg_at_10: 0.0047 - map_at_10: 0.0030 - precision_at_10: 0.0010 - regularization_loss: 0.0000e+00 - loss_batch: 6.9095\n" - ] - }, - { - "data": { - "text/plain": [ - "<keras.callbacks.History at 0x...>" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.compile(run_eagerly=False, optimizer='adam', loss=\"categorical_crossentropy\")\n", - "model.fit(loader, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS, pre=mm.SequenceMaskRandom(schema=loader.output_schema, target=target, masking_prob=0.3, transformer=transformer_block))" - ] - }, - { - "cell_type": "markdown", - "id": "fa8ab17b", - "metadata": {}, - "source": [ - "## Serving predictions" - ] - }, - { - "cell_type": "markdown", - "id": "c778420d", - "metadata": {}, - "source": [ - "Now that we have prepared a workflow for processing our data (`wf`), defined the embedding operator (`embedding_operator`) and trained our model (`model`), we have all the components we need to serve our model using the Triton Inference Server (TIS).\n", - "\n", - "Let us define a set of inference operators (a pipeline for processing our data all the way to obtaining predictions) and export them as an ensemble that we will be able to serve using TIS."
- ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "18f19033", - "metadata": {}, - "outputs": [], - "source": [ - "from merlin.systems.dag.ops.tensorflow import PredictTensorflow\n", - "from merlin.systems.dag.ensemble import Ensemble\n", - "from merlin.systems.dag.ops.workflow import TransformWorkflow" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "385aba04", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n", - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n", - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, sequence_mask_random_layer_call_fn, sequence_mask_random_layer_call_and_return_conditional_losses, prepare_list_features_1_layer_call_fn while saving (showing 5 of 110). 
These functions will not be directly callable after loading.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /tmp/tmpi3g8g7q7/assets\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /tmp/tmpi3g8g7q7/assets\n" - ] - } - ], - "source": [ - "inference_operators = wf.input_schema.column_names >> TransformWorkflow(wf) >> embedding_operator >> PredictTensorflow(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "1c14a25d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " 
(product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (product_sku_hash_list): TensorShape([16, None, 1])\n", - " (event_type_list): TensorShape([16, None, 1])\n", - " (product_action_list): TensorShape([16, None, 1])\n", - " (hashed_url_list): TensorShape([16, None, 1])\n", - " (product_embeddings): TensorShape([16, None, 50])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (product_sku_hash_list): tf.int64\n", - " (event_type_list): tf.int64\n", - " (product_action_list): tf.int64\n", - " (hashed_url_list): tf.int64\n", - " (product_embeddings): tf.float32\n", - " )\n", - "), because it is not built.\n", - "WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, sequence_mask_random_layer_call_fn, sequence_mask_random_layer_call_and_return_conditional_losses, prepare_list_features_1_layer_call_fn while saving (showing 5 of 110). These functions will not be directly callable after loading.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets\n", - "/usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:101: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " config[key] = tf.keras.utils.serialize_keras_object(maybe_value)\n", - "/usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:288: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " config[i] = tf.keras.utils.serialize_keras_object(layer)\n", - "/usr/local/lib/python3.8/dist-packages/keras/saving/legacy/saved_model/layer_serialization.py:134: CustomMaskWarning: Custom mask layers require a config and must override get_config. 
When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " return serialization.serialize_keras_object(obj)\n", - "/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initalizer instance more than once.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" - ] - } - ], - "source": [ - "ensemble = Ensemble(inference_operators, wf.input_schema)\n", - "ensemble.export(os.path.join(OUTPUT_DATA_DIR, 'ensemble'));" - ] - }, - { - "cell_type": "markdown", - "id": "264fd1ea", - "metadata": {}, - "source": [ - "After we export the ensemble, we are ready to start the Triton Inference Server.\n", - "\n", - "The server is installed in Merlin Tensorflow and Merlin PyTorch containers. If you are not using one of our containers, then ensure it is installed in your environment. For more information, see the Triton Inference Server [documentation](https://github.com/triton-inference-server/server/blob/r22.03/README.md#documentation).\n", - "\n", - "You can start the server by running the following command:\n", - "\n", - "```tritonserver --model-repository={OUTPUT_DATA_DIR}/ensemble/```\n", - "\n", - "For the --model-repository argument, specify the same value as the `export_path` that you specified previously in the `ensemble.export` method.\n", - "\n", - "After you run the `tritonserver` command, wait until your terminal shows messages like the following example:\n", - "\n", - "I0414 18:29:50.741833 4067 grpc_server.cc:4421] Started GRPCInferenceService at 0.0.0.0:8001
\n", - "I0414 18:29:50.742197 4067 http_server.cc:3113] Started HTTPService at 0.0.0.0:8000
\n", - "I0414 18:29:50.783470 4067 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002\n", - "\n", - "Let us now package our data for inference. We will send 5 rows of data, which corresponds to a single customer journey (session) through the online store. The data will be first processed by the `NVTabular` workflow and subsequentally passed to our transformer model for predicting. " - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "90483210", - "metadata": {}, - "outputs": [], - "source": [ - "# obtaining five rows of data\n", - "df = train.head(5)\n", - "# making sure all the rows correspond to the same online session (have the same `session_id_hash`)\n", - "df['session_id_hash'] = df['session_id_hash'].iloc[0]" - ] - }, - { - "cell_type": "markdown", - "id": "efdf671e", - "metadata": {}, - "source": [ - "Let us now send the data to the Triton Inference Server for inference." - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "d8453048", - "metadata": {}, - "outputs": [], - "source": [ - "from merlin.systems.triton import convert_df_to_triton_input\n", - "import tritonclient.grpc as grpcclient\n", - "\n", - "inputs = convert_df_to_triton_input(wf.input_schema, df)\n", - "\n", - "with grpcclient.InferenceServerClient(\"localhost:8001\") as client:\n", - " response = client.infer('executor_model', inputs)" - ] - }, - { - "cell_type": "markdown", - "id": "913b80e8", - "metadata": {}, - "source": [ - "Let's parse the response." - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "4cc4b046", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-2.2332087 , -2.1218574 , -2.390479 , ..., -0.7735352 ,\n", - " 0.1954267 , -0.34523243]], dtype=float32)" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "predictions = response.as_numpy(\"hashed_url_list/categorical_output\")\n", - "predictions" - ] - }, - { - "cell_type": "markdown", - "id": "e49c2ed9", - "metadata": {}, - "source": [ - "The response contains logits predicting the id of the url the customer is most likely to arrive at as next step of their journey through the online store.\n", - "\n", - "Here is the predicted hashed url id:" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "0b9af2ae", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "34" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "predicted_hashed_url_id = predictions.argmax()\n", - "predicted_hashed_url_id" - ] - }, - { - "cell_type": "markdown", - "id": "8ef47efd", - "metadata": {}, - "source": [ - "## Summary\n", - "\n", - "We have trained a transformer model for the next item prediction task using language model masking.\n", - "\n", - "For another session-based example that goes deeper into data preprocessing and that covers several advanced techniques (Weight Tying, Temperature Scaling) please see [Session-Based Next Item Prediction for Fashion E-Commerce](https://github.com/NVIDIA-Merlin/models/blob/t4rec_use_case/examples/usecases/ecommerce-session-based-next-item-prediction-for-fashion.ipynb). 
" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/usecases/transformers-next-item-prediction.ipynb b/examples/usecases/transformers-next-item-prediction.ipynb deleted file mode 100644 index b409170f00..0000000000 --- a/examples/usecases/transformers-next-item-prediction.ipynb +++ /dev/null @@ -1,1516 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "a556f660", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright 2022 NVIDIA Corporation. All Rights Reserved.\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions anda\n", - "# limitations under the License.\n", - "# ==============================================================================\n", - "\n", - "# Each user is responsible for checking the content of datasets and the\n", - "# applicable licenses and determining if suitable for the intended use." - ] - }, - { - "cell_type": "markdown", - "id": "697d1452", - "metadata": {}, - "source": [ - "\n", - "\n", - "# Transformer-based architecture for next-item prediction task\n", - "\n", - "This notebook is created using the latest stable [merlin-tensorflow](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/merlin/containers/merlin-tensorflow/tags) container.\n", - "\n", - "## Overview\n", - "\n", - "In this use case we will train a Transformer-based architecture for next-item prediction task.\n", - "\n", - "**Note, the data for this notebook will be automatically downloaded to the folder specified in the cells below.**\n", - "\n", - "We will use the [booking.com dataset](https://github.com/bookingcom/ml-dataset-mdt) to train a session-based model. The dataset contains 1,166,835 of anonymized hotel reservations in the train set and 378,667 in the test set. Each reservation is a part of a customer's trip (identified by `utrip_id`) which includes consecutive reservations.\n", - "\n", - "We will reshape the data to organize it into 'sessions'. Each session will be a full customer itinerary in chronological order. The goal will be to predict the city_id of the final reservation of each trip.\n", - "\n", - "\n", - "### Learning objectives\n", - "\n", - "- Training a Transformer-based architecture for next-item prediction task" - ] - }, - { - "cell_type": "markdown", - "id": "1cccd005", - "metadata": {}, - "source": [ - "## Downloading and preparing the dataset" - ] - }, - { - "cell_type": "markdown", - "id": "1d0b619b", - "metadata": {}, - "source": [ - "We will download the dataset using a functionality provided by merlin models. 
The dataset can be found on GitHub [here](https://github.com/bookingcom/ml-dataset-mdt).\n", - "\n", - "**Read more about libraries used in the import statements below**\n", - "\n", - "- [get_lib](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/core/dispatch.py)\n", - "- [get_booking](https://github.com/NVIDIA-Merlin/models/tree/stable/merlin/datasets/ecommerce)\n", - "- [nvtabular](https://github.com/NVIDIA-Merlin/NVTabular/tree/stable/nvtabular)\n", - "- [nvtabular ops](https://github.com/NVIDIA-Merlin/NVTabular/tree/stable/nvtabular/ops)\n", - "- [schema tags](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/schema/tags.py)\n", - "- [merlin models tensorflow](https://github.com/NVIDIA-Merlin/models/tree/stable/merlin/models/tf)\n", - "- [get_booking](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/datasets/ecommerce/booking/dataset.py)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "40e9ef05", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-05-31 06:06:25.697025: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\n", - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/torch.py:43: UserWarning: PyTorch dtype mappings did not load successfully due to an error: No module named 'torch'\n", - " warn(f\"PyTorch dtype mappings did not load successfully due to an error: {exc.msg}\")\n", - "2023-05-31 06:06:26.988036: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:26.988386: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:26.988518: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INFO]: sparse_operation_kit is imported\n", - "WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.base has been moved to tensorflow.python.trackable.base. 
The old module will be deleted in version 2.11.\n", - "[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so\n", - "[SOK INFO] Import /usr/local/lib/python3.8/dist-packages/merlin_sok-1.1.4-py3.8-linux-x86_64.egg/sparse_operation_kit/lib/libsok_experiment.so\n", - "[SOK INFO] Initialize finished, communication tool: horovod\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-05-31 06:06:28.519868: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\n", - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2023-05-31 06:06:28.520815: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:28.520999: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:28.521129: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:28.591345: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:28.591534: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:28.591665: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero\n", - "2023-05-31 06:06:28.591770: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.\n", - "2023-05-31 06:06:28.591778: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:222] Using CUDA malloc Async allocator for GPU: 0\n", - "2023-05-31 06:06:28.591860: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1621] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 24576 MB memory: -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:08:00.0, compute capability: 7.5\n", - "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "# Resetting the TF memory allocation to not be 50% by default. 
\n", - "import os\n", - "os.environ[\"TF_GPU_ALLOCATOR\"]=\"cuda_malloc_async\"\n", - "\n", - "from merlin.core.dispatch import get_lib\n", - "from merlin.datasets.ecommerce import get_booking\n", - "\n", - "import numpy as np\n", - "import timeit\n", - "\n", - "from nvtabular import *\n", - "from nvtabular import ops\n", - "\n", - "from merlin.schema.tags import Tags\n", - "import merlin.models.tf as mm\n", - "\n", - "INPUT_DATA_DIR = os.environ.get('INPUT_DATA_DIR', '/workspace/data')\n", - "OUTPUT_DATA_DIR = os.environ.get('OUTPUT_DATA_DIR', '/workspace/data')\n", - "NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', '5'))" - ] - }, - { - "cell_type": "markdown", - "id": "c1b42076", - "metadata": {}, - "source": [ - "Let's download the data." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "d0a33352", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.USER_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.SESSION_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "get_booking(INPUT_DATA_DIR)" - ] - }, - { - "cell_type": "markdown", - "id": "ee9dd8c8", - "metadata": {}, - "source": [ - "Each reservation has a unique utrip_id. During each trip a customer vists several destinations." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "01d1b755", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " user_id checkin checkout city_id device_class affiliate_id \\\n", - "0 1000027 2016-08-13 2016-08-14 8183 desktop 7168 \n", - "1 1000027 2016-08-14 2016-08-16 15626 desktop 7168 \n", - "2 1000027 2016-08-16 2016-08-18 60902 desktop 7168 \n", - "3 1000027 2016-08-18 2016-08-21 30628 desktop 253 \n", - "4 1000033 2016-04-09 2016-04-11 38677 mobile 359 \n", - "\n", - " booker_country hotel_country utrip_id \n", - "0 Elbonia Gondal 1000027_1 \n", - "1 Elbonia Gondal 1000027_1 \n", - "2 Elbonia Gondal 1000027_1 \n", - "3 Elbonia Gondal 1000027_1 \n", - "4 Gondal Cobra Island 1000033_1 \n" - ] - } - ], - "source": [ - "# When displaying cudf dataframes use print() or display(), otherwise Jupyter creates hidden copies.\n", - "train = get_lib().read_csv(f'{INPUT_DATA_DIR}/train_set.csv', parse_dates=['checkin', 'checkout'])\n", - "print(train.head())" - ] - }, - { - "cell_type": "markdown", - "id": "fecc2d94", - "metadata": {}, - "source": [ - "We will train on sequences of `city_id` and `booker_country` and based on this information, our model will attempt to predict the next `city_id` (the next hop in the journey).\n", - "\n", - "We will train a transformer model that can work with sequences of variable length within a batch. 
- { - "cell_type": "markdown", - "id": "fecc2d94", - "metadata": {}, - "source": [ - "We will train on sequences of `city_id` and `booker_country`, and based on this information our model will attempt to predict the next `city_id` (the next hop in the journey).\n", - "\n", - "We will train a transformer model that can work with sequences of variable length within a batch. This functionality is provided to us out of the box and doesn't require any changes to the architecture. Thanks to this, we do not have to pad or trim our sequences to any particular length -- our model can make effective use of all of the data!\n", - "\n", - "*With one exception.* For the masked language model that we will be training, we need to discard sequences that are shorter than two hops. This makes sense, as there is nothing our model could learn if it was only presented with an itinerary with a single destination on it!\n", - "\n", - "Let us begin by splitting the data into a train and validation set based on trip ID.\n", - "\n", - "Let's see how many unique trips there are in the dataset. Also, let us shuffle the trips along the way so that our validation set consists of a random sample of our train set." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "23bef6ae", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of unique trips is : 217686\n" - ] - } - ], - "source": [ - "# Unique trip ids.\n", - "utrip_ids = train.sample(frac=1).utrip_id.unique()\n", - "print('Number of unique trips is :', len(utrip_ids))" - ] - }, - { - "cell_type": "markdown", - "id": "f7eca1f6", - "metadata": {}, - "source": [ - "Now let's assign data to our train and validation sets. Furthermore, we sort the data by `utrip_id` and `checkin`. This way we ensure our sequences of visited `city_ids` will be in proper order!\n", - "\n", - "Also, let's remove trips where only a single city was visited, as they cannot be modeled as a sequence." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "7754847c", - "metadata": {}, - "outputs": [], - "source": [ - "train = get_lib().from_pandas(\n", - "    train.to_pandas().join(train.to_pandas().groupby('utrip_id').size().rename('num_examples'), on='utrip_id')\n", - ")\n", - "train = train[train.num_examples > 1]\n", - "\n", - "train.checkin = train.checkin.astype('int')\n", - "train.checkout = train.checkout.astype('int')\n", - "\n", - "train_set_utrip_ids = utrip_ids[:int(0.8 * utrip_ids.shape[0])]\n", - "validation_set_utrip_ids = utrip_ids[int(0.8 * utrip_ids.shape[0]):]\n", - "\n", - "train_set = train[train.utrip_id.isin(train_set_utrip_ids)].sort_values(['utrip_id', 'checkin'])\n", - "validation_set = train[train.utrip_id.isin(validation_set_utrip_ids)].sort_values(['utrip_id', 'checkin'])" - ] - }, - { - "cell_type": "markdown", - "id": "79cc3992", - "metadata": {}, - "source": [ - "## Preprocessing with NVTabular\n", - "\n", - "We can now begin data preprocessing.\n", - "\n", - "We will combine trips into \"sessions\", discard trips that are too short, and calculate the total trip length.\n", - "\n", - "We will use NVTabular for this work. It offers optimized tabular data preprocessing operators that run on the GPU. 
If you would like to learn more about the NVTabular library, please take a look [here](https://github.com/NVIDIA-Merlin/NVTabular).\n", - "\n", - "Read more about [Merlin's Dataset API](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/io/dataset.py)  \n", - "Read more about how [parquet files are read in and processed by Merlin](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/io/parquet.py)  \n", - "Read more about [Tags](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/schema/tags.py)  \n", - "- [schema_select_by_tag](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/schema/schema.py)  \n", - "\n", - "Read more about [NVTabular Workflows](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/workflow/workflow.py)  \n", - "- [fit_transform](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/workflow/workflow.py)\n", - "- [transform](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/workflow/workflow.py)  \n", - "\n", - "Read more about the NVTabular Operators:  \n", - "- [Categorify](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/ops/categorify.py)\n", - "- [AddTags](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/ops/add_metadata.py)\n", - "- [LambdaOp](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/ops/lambdaop.py)\n", - "- [Rename](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/ops/rename.py)\n", - "- [Filter](https://github.com/NVIDIA-Merlin/NVTabular/blob/stable/nvtabular/ops/filter.py)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3435af68", - "metadata": {}, - "outputs": [], - "source": [ - "train_set_dataset = Dataset(train_set)\n", - "validation_set_dataset = Dataset(validation_set)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "60bd5e59", - "metadata": {}, - "outputs": [], - "source": [ - "weekday_checkin = (\n", - "    [\"checkin\"]\n", - "    >> ops.LambdaOp(lambda col: get_lib().to_datetime(col).dt.weekday)\n", - "    >> ops.Rename(name=\"weekday_checkin\")\n", - ")\n", - "\n", - "weekday_checkout = (\n", - "    [\"checkout\"]\n", - "    >> ops.LambdaOp(lambda col: get_lib().to_datetime(col).dt.weekday)\n", - "    >> ops.Rename(name=\"weekday_checkout\")\n", - ")\n", - "\n", - "categorical_features = (['city_id', 'booker_country', 'hotel_country'] +\n", - "                        weekday_checkin + weekday_checkout\n", - "                       ) >> ops.Categorify()\n", - "\n", - "groupby_features = categorical_features + ['utrip_id', 'checkin'] >> ops.Groupby(\n", - "    groupby_cols=['utrip_id'],\n", - "    aggs={\n", - "        'city_id': ['list', 'count'],\n", - "        'booker_country': ['list'],\n", - "        'hotel_country': ['list'],\n", - "        'weekday_checkin': ['list'],\n", - "        'weekday_checkout': ['list']\n", - "    },\n", - "    sort_cols=\"checkin\"\n", - ")\n", - "\n", - "list_features = (\n", - "    groupby_features['city_id_list', 'booker_country_list', 'hotel_country_list', \n", - "                     'weekday_checkin_list', 'weekday_checkout_list'\n", - "                    ] >> ops.AddTags([Tags.SEQUENCE])\n", - ")\n", - "\n", - "# Filter out sessions with fewer than 2 interactions \n", - "MINIMUM_SESSION_LENGTH = 2\n", - "features = list_features + (groupby_features['city_id_count'] >> ops.AddTags([Tags.CONTINUOUS]))\n", - "filtered_sessions = features >> ops.Filter(f=lambda df: df[\"city_id_count\"] >= MINIMUM_SESSION_LENGTH) " - ] - },
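- { - "cell_type": "markdown", - "id": "5d7e9f2a", - "metadata": {}, - "source": [ - "To build intuition for what the `Groupby` op above produces, here is a toy sketch of the equivalent reshaping in plain pandas (invented values, not our dataset):\n", - "\n", - "```python\n", - "import pandas as pd\n", - "\n", - "toy = pd.DataFrame({\n", - "    'utrip_id': ['t1', 't1', 't1', 't2', 't2'],\n", - "    'checkin':  [1, 2, 3, 1, 2],\n", - "    'city_id':  [10, 11, 12, 20, 21],\n", - "})\n", - "\n", - "# Each trip collapses into one row of chronologically ordered lists,\n", - "# plus a count column -- conceptually what ops.Groupby does above.\n", - "sessions = (toy.sort_values('checkin')\n", - "               .groupby('utrip_id')\n", - "               .agg(city_id_list=('city_id', list),\n", - "                    city_id_count=('city_id', 'size')))\n", - "print(sessions)\n", - "```" - ] - },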
"\n", - "wf.fit_transform(train_set_dataset).to_parquet(os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet'))\n", - "wf.transform(validation_set_dataset).to_parquet(os.path.join(OUTPUT_DATA_DIR, 'validation_processed.parquet'))\n", - "\n", - "wf.save(os.path.join(OUTPUT_DATA_DIR, 'workflow'))" - ] - }, - { - "cell_type": "markdown", - "id": "539a6675", - "metadata": {}, - "source": [ - "Our data consists of a sequence of visited `city_ids`, a sequence of `booker_countries` (represented as integer categories) and a `city_id_count` column (which contains the count of visited cities in a trip)." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "2dee6b53", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
city_id_listbooker_country_listhotel_country_listweekday_checkin_listweekday_checkout_listcity_id_count
0[8238, 156, 2278, 2097][3, 3, 3, 3][3, 3, 3, 3][5, 7, 4, 3][7, 4, 2, 7]4
1[63, 1160, 87, 618, 63][1, 1, 1, 1, 1][1, 1, 1, 1, 1][5, 1, 4, 3, 5][6, 4, 2, 5, 4]5
2[7, 6, 24, 1050, 65, 52, 3][2, 2, 2, 2, 2, 2, 2][2, 2, 2, 16, 16, 3, 3][5, 1, 2, 6, 5, 7, 4][6, 3, 1, 5, 7, 4, 3]7
3[1032, 757, 140, 3][2, 2, 2, 2][19, 19, 19, 3][1, 4, 2, 3][4, 3, 2, 5]4
4[3603, 262, 662, 250, 359][1, 1, 1, 1, 1][30, 30, 30, 30, 30][1, 3, 6, 5, 1][2, 1, 5, 6, 3]5
\n", - "
" - ], - "text/plain": [ - " city_id_list booker_country_list \\\n", - "0 [8238, 156, 2278, 2097] [3, 3, 3, 3] \n", - "1 [63, 1160, 87, 618, 63] [1, 1, 1, 1, 1] \n", - "2 [7, 6, 24, 1050, 65, 52, 3] [2, 2, 2, 2, 2, 2, 2] \n", - "3 [1032, 757, 140, 3] [2, 2, 2, 2] \n", - "4 [3603, 262, 662, 250, 359] [1, 1, 1, 1, 1] \n", - "\n", - " hotel_country_list weekday_checkin_list weekday_checkout_list \\\n", - "0 [3, 3, 3, 3] [5, 7, 4, 3] [7, 4, 2, 7] \n", - "1 [1, 1, 1, 1, 1] [5, 1, 4, 3, 5] [6, 4, 2, 5, 4] \n", - "2 [2, 2, 2, 16, 16, 3, 3] [5, 1, 2, 6, 5, 7, 4] [6, 3, 1, 5, 7, 4, 3] \n", - "3 [19, 19, 19, 3] [1, 4, 2, 3] [4, 3, 2, 5] \n", - "4 [30, 30, 30, 30, 30] [1, 3, 6, 5, 1] [2, 1, 5, 6, 3] \n", - "\n", - " city_id_count \n", - "0 4 \n", - "1 5 \n", - "2 7 \n", - "3 4 \n", - "4 5 " - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Dataset(os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet')).head()" - ] - }, - { - "cell_type": "markdown", - "id": "e89cc3a0", - "metadata": {}, - "source": [ - "We are now ready to train our model." - ] - }, - { - "cell_type": "markdown", - "id": "ce95c794", - "metadata": {}, - "source": [ - "Here is the schema of the data that our model will use." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "c4813456", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nametagsdtypeis_listis_raggedproperties.num_bucketsproperties.freq_thresholdproperties.max_sizeproperties.start_indexproperties.cat_pathproperties.domain.minproperties.domain.maxproperties.domain.nameproperties.embedding_sizes.cardinalityproperties.embedding_sizes.dimensionproperties.value_count.minproperties.value_count.max
0city_id_list(Tags.SEQUENCE, Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....TrueTrueNone000.//categories/unique.city_id.parquet037202city_id372035120None
1booker_country_list(Tags.SEQUENCE, Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....TrueTrueNone000.//categories/unique.booker_country.parquet05booker_country6160None
2hotel_country_list(Tags.SEQUENCE, Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....TrueTrueNone000.//categories/unique.hotel_country.parquet0194hotel_country195310None
3weekday_checkin_list(Tags.SEQUENCE, Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....TrueTrueNone000.//categories/unique.weekday_checkin.parquet07weekday_checkin8160None
4weekday_checkout_list(Tags.SEQUENCE, Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....TrueTrueNone000.//categories/unique.weekday_checkout.parquet07weekday_checkout8160None
\n", - "
" - ], - "text/plain": [ - "[{'name': 'city_id_list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.city_id.parquet', 'domain': {'min': 0, 'max': 37202, 'name': 'city_id'}, 'embedding_sizes': {'cardinality': 37203, 'dimension': 512}, 'value_count': {'min': 0, 'max': None}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=None)))), 'is_list': True, 'is_ragged': True}, {'name': 'booker_country_list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.booker_country.parquet', 'domain': {'min': 0, 'max': 5, 'name': 'booker_country'}, 'embedding_sizes': {'cardinality': 6, 'dimension': 16}, 'value_count': {'min': 0, 'max': None}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=None)))), 'is_list': True, 'is_ragged': True}, {'name': 'hotel_country_list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.hotel_country.parquet', 'domain': {'min': 0, 'max': 194, 'name': 'hotel_country'}, 'embedding_sizes': {'cardinality': 195, 'dimension': 31}, 'value_count': {'min': 0, 'max': None}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=None)))), 'is_list': True, 'is_ragged': True}, {'name': 'weekday_checkin_list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.weekday_checkin.parquet', 'domain': {'min': 0, 'max': 7, 'name': 'weekday_checkin'}, 'embedding_sizes': {'cardinality': 8, 'dimension': 16}, 'value_count': {'min': 0, 'max': None}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=None)))), 'is_list': True, 'is_ragged': True}, {'name': 'weekday_checkout_list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.weekday_checkout.parquet', 'domain': {'min': 0, 'max': 7, 'name': 'weekday_checkout'}, 'embedding_sizes': {'cardinality': 8, 'dimension': 16}, 'value_count': {'min': 0, 'max': None}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=None)))), 'is_list': True, 'is_ragged': True}]" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "seq_schema = Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_tag(Tags.SEQUENCE)\n", - "seq_schema" - ] - }, - { - "cell_type": "markdown", - "id": "8d422833", - "metadata": {}, - "source": [ - "Let's also identify the target column." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "2b90424a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'city_id_list'" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target = Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_tag(Tags.SEQUENCE).column_names[0]\n", - "target" - ] - }, - { - "cell_type": "markdown", - "id": "e9d8adad", - "metadata": {}, - "source": [ - "## Constructing the model" - ] - }, - { - "cell_type": "markdown", - "id": "c4cb17fe", - "metadata": {}, - "source": [ - "Let's construct our model.\n", - "\n", - "We can specify various hyperparameters, such as the number of heads and number of layers to use." - ] - }, - { - "cell_type": "markdown", - "id": "0a460e4c", - "metadata": {}, - "source": [ - "For the transformer portion of our model, we will use the `XLNet` architecture." - ] - }, - { - "cell_type": "markdown", - "id": "23bf02dc", - "metadata": {}, - "source": [ - "Later, when we run the `fit` method on our model, we will specify a `masking_prob` of `0.3` and link it to the transformer block defined in our model. Through the combination of these parameters, our model will train on sequences where any given timestep is masked with a probability of 0.3, and it will be the model's training task to infer the target value for that step! A small illustrative sketch follows the reference list below.\n", - "\n", - "To summarize, Masked Language Modeling is implemented by:\n", - "\n", - "* `SequenceMaskRandom()` - Used as the `pre` argument of `model.fit()`, it randomly selects items from the sequence to be masked as prediction targets, using Keras masking. This block also adds the necessary configuration to the specified `transformer` block so that it is pre-configured with the layers needed to prepare the inputs to the HuggingFace transformer layer and to post-process its outputs. For example, one pre-processing operation replaces the input embeddings at masked positions with a dummy trainable embedding, to avoid leakage of the targets.\n", - "\n", - "\n", - "**Read more about the APIs used to construct models**  \n", - "- [blocks](https://github.com/NVIDIA-Merlin/models/tree/stable/merlin/models/tf/blocks)\n", - "- [MLPBlock](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/blocks/mlp.py)\n", - "- [InputBlockV2](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/inputs/base.py)\n", - "- [Embeddings](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/inputs/embedding.py)\n", - "- [XLNetBlock](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/transformers/block.py)\n", - "- [CategoricalOutput](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/outputs/classification.py)\n", - "- [.schema.select_by_name](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/schema/schema.py)\n", - "- [.schema.select_by_tag](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/schema/schema.py)\n", - "- [model.compile()](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/models/base.py)\n", - "- [model.fit()](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/models/base.py)\n", - "- [model.evaluate()](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/models/base.py)\n", - "- [mm.SequenceMaskRandom](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/transforms/sequence.py)\n", - "- [mm.SequenceMaskLast](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/transforms/sequence.py)" - ] - },
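- { - "cell_type": "markdown", - "id": "6e8f1a3b", - "metadata": {}, - "source": [ - "As a conceptual illustration of the random masking described above, here is a toy numpy sketch (not the actual implementation, which also handles ragged batches and edge cases):\n", - "\n", - "```python\n", - "import numpy as np\n", - "\n", - "rng = np.random.default_rng(42)\n", - "session = np.array([17, 3, 250, 41, 9])  # toy encoded city ids\n", - "\n", - "# Each position becomes a prediction target with probability 0.3.\n", - "mask = rng.random(session.shape) < 0.3\n", - "targets = np.where(mask, session, -1)  # -1 marks 'not a target' in this sketch\n", - "inputs = np.where(mask, 0, session)    # masked inputs replaced by a dummy value\n", - "print(mask, inputs, targets, sep='\\n')\n", - "```" - ] - },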
For example, one pre-processing operation is to replace the input embeddings at masked positions for prediction by a dummy trainable embedding, to avoid leakage of the targets.\n", - "\n", - "\n", - "**Read more about the apis used to construct models** \n", - "- [blocks](https://github.com/NVIDIA-Merlin/models/tree/stable/merlin/models/tf/blocks)\n", - "- [MLPBlock](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/blocks/mlp.py)\n", - "- [InputBlockV2](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/inputs/base.py)\n", - "- [Embeddings](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/inputs/embedding.py)\n", - "- [XLNetBlock](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/transformers/block.py)\n", - "- [CategoricalOutput](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/outputs/classification.py)\n", - "- [.schema.select_by_name](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/schema/schema.py)\n", - "- [.schema.select_by_tag](https://github.com/NVIDIA-Merlin/core/blob/stable/merlin/schema/schema.py)\n", - "- [model.compile()](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/models/base.py)\n", - "- [model.fit()](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/models/base.py)\n", - "- [model.evaluate()](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/models/base.py)\n", - "- [mm.SequenceMaskRandom](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/transforms/sequence.py)\n", - "- [mm.SequenceMaskLast](https://github.com/NVIDIA-Merlin/models/blob/stable/merlin/models/tf/transforms/sequence.py)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "cddfd424", - "metadata": {}, - "outputs": [], - "source": [ - "dmodel=48\n", - "mlp_block = mm.MLPBlock(\n", - " [128,dmodel],\n", - " activation='relu',\n", - " no_activation_last_layer=True,\n", - " )\n", - "transformer_block = mm.XLNetBlock(d_model=dmodel, n_head=4, n_layer=2)\n", - "model = mm.Model(\n", - " mm.InputBlockV2(\n", - " seq_schema,\n", - " embeddings=mm.Embeddings(\n", - " Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None\n", - " ),\n", - " ),\n", - " mlp_block,\n", - " transformer_block,\n", - " mm.CategoricalOutput(\n", - " Workflow.load(os.path.join(OUTPUT_DATA_DIR, 'workflow')).output_schema.select_by_name(target),\n", - " default_loss=\"categorical_crossentropy\",\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "aac975cd", - "metadata": {}, - "source": [ - "## Model training" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "65d28c27", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). 
Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", - "  warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/5\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-05-31 06:06:44.034041: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8700\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Gradients do not exist for variables ['model/mask_emb:0', 'transformer/layer_._0/rel_attn/r_s_bias:0', 'transformer/layer_._0/rel_attn/seg_embed:0', 'transformer/layer_._1/rel_attn/r_s_bias:0', 'transformer/layer_._1/rel_attn/seg_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?\n", - "WARNING:tensorflow:Gradients do not exist for variables ['model/mask_emb:0', 'transformer/layer_._0/rel_attn/r_s_bias:0', 'transformer/layer_._0/rel_attn/seg_embed:0', 'transformer/layer_._1/rel_attn/r_s_bias:0', 'transformer/layer_._1/rel_attn/seg_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-05-31 06:06:54.541024: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: model/xl_net_block/sequential_block_5/replace_masked_embeddings/RaggedWhere/Assert/AssertGuard/branch_executed/_95\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2720/2720 [==============================] - 81s 25ms/step - loss: 7.3315 - recall_at_10: 0.1973 - mrr_at_10: 0.0863 - ndcg_at_10: 0.1123 - map_at_10: 0.0863 - precision_at_10: 0.0197 - regularization_loss: 0.0000e+00 - loss_batch: 7.3306\n", - "Epoch 2/5\n", - "2720/2720 [==============================] - 70s 25ms/step - loss: 6.0979 - recall_at_10: 0.3633 - mrr_at_10: 0.1707 - ndcg_at_10: 0.2161 - map_at_10: 0.1707 - precision_at_10: 0.0363 - regularization_loss: 0.0000e+00 - loss_batch: 6.0950\n", - "Epoch 3/5\n", - "2720/2720 [==============================] - 71s 26ms/step - loss: 5.5827 - recall_at_10: 0.4306 - mrr_at_10: 0.2056 - ndcg_at_10: 0.2588 - map_at_10: 0.2056 - precision_at_10: 0.0431 - regularization_loss: 0.0000e+00 - loss_batch: 5.5806\n", - "Epoch 4/5\n", - "2720/2720 [==============================] - 72s 26ms/step - loss: 5.3211 - recall_at_10: 0.4627 - mrr_at_10: 0.2213 - ndcg_at_10: 0.2784 - map_at_10: 0.2213 - precision_at_10: 0.0463 - regularization_loss: 0.0000e+00 - loss_batch: 5.3194\n", - "Epoch 5/5\n", - "2720/2720 [==============================] - 71s 26ms/step - loss: 5.1920 - recall_at_10: 0.4787 - mrr_at_10: 0.2306 - ndcg_at_10: 0.2892 - map_at_10: 0.2306 - precision_at_10: 0.0479 - regularization_loss: 0.0000e+00 - loss_batch: 5.1903\n" - ] - }, - { - "data": { - "text/plain": [ - "<keras.callbacks.History at 0x...>" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model.compile(run_eagerly=False, optimizer='adam', loss=\"categorical_crossentropy\")\n", - "\n", - "model.fit(\n", - "    Dataset(os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet')),\n", - "    batch_size=64,\n", - "    epochs=NUM_EPOCHS,\n", - "    pre=mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "24699106", - 
"metadata": {}, - "source": [ - "## Model evaluation" - ] - }, - { - "cell_type": "markdown", - "id": "73d87d27", - "metadata": {}, - "source": [ - "We have trained our model.\n", - "\n", - "But in training the metrics come from a masked language modelling task. A portion of steps in the sequence was masked for each example. The metrics were calculated on this task.\n", - "\n", - "In reality, we probably care how well our model does on the next item prediction task (as it mimics the scenario in which the model would be likely to be used).\n", - "\n", - "Let's measure the performance of the model on a task where it attempts to predict the last item in a sequence.\n", - "\n", - "We will mask the last item using `SequenceMaskLast` and run inference." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "bb3c6358", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-05-31 06:12:51.968982: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: model/xl_net_block/sequential_block_5/replace_masked_embeddings/RaggedWhere/Assert/AssertGuard/branch_executed/_74\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "340/340 [==============================] - 11s 20ms/step - loss: 4.7151 - recall_at_10: 0.5533 - mrr_at_10: 0.3083 - ndcg_at_10: 0.3665 - map_at_10: 0.3083 - precision_at_10: 0.0553 - regularization_loss: 0.0000e+00 - loss_batch: 4.7149\n" - ] - } - ], - "source": [ - "metrics = model.evaluate(\n", - " Dataset(os.path.join(OUTPUT_DATA_DIR, 'validation_processed.parquet')),\n", - " batch_size=128,\n", - " pre=mm.SequenceMaskLast(schema=seq_schema, target=target, transformer=transformer_block),\n", - " return_dict=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "83ca276f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'loss': 4.715089797973633,\n", - " 'recall_at_10': 0.5533444881439209,\n", - " 'mrr_at_10': 0.30831339955329895,\n", - " 'ndcg_at_10': 0.36654922366142273,\n", - " 'map_at_10': 0.30831339955329895,\n", - " 'precision_at_10': 0.055334459990262985,\n", - " 'regularization_loss': 0.0,\n", - " 'loss_batch': 4.635858535766602}" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metrics" - ] - }, - { - "cell_type": "markdown", - "id": "9fb5cc29", - "metadata": {}, - "source": [ - "## Serving predictions using the Triton Inference Server" - ] - }, - { - "cell_type": "markdown", - "id": "9dc6ee5f", - "metadata": {}, - "source": [ - "Now, we will serve our trained models on [NVIDIA Triton Inference Server (TIS)](https://github.com/triton-inference-server/server). TIS is an open-source inference serving software that helps standardize model deployment and execution and delivers fast and scalable AI in production. To serve recommender models on TIS easily, NVIDIA Merlin team designed and developed [the Merlin Systems library](https://github.com/NVIDIA-Merlin/systems). 
Merlin Systems provides tools and operators to be able to serve end-to-end recommender systems pipelines on TIS easily\n", - "\n", - "In order to perform inference on the Triton Inference Server, we need to output the inference operators to disk.\n", - "\n", - "The inference operators form an `Ensemble`, which is a pipeline that takes in raw data, processes it using NVTabular, and finally outputs predictions from the model that we trained.\n", - "\n", - "Let's write the `Ensemble` to disk (we will later load it on Triton to perform inference)." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "7ae33813", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n", - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n", - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, sequence_mask_random_layer_call_fn, sequence_mask_random_layer_call_and_return_conditional_losses, sequence_mask_last_layer_call_fn while saving (showing 5 of 108). 
These functions will not be directly callable after loading.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /tmp/tmp1sakw940/model.savedmodel/assets\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /tmp/tmp1sakw940/model.savedmodel/assets\n", - "/usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:101: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " config[key] = tf.keras.utils.serialize_keras_object(maybe_value)\n", - "/usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:288: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " config[i] = tf.keras.utils.serialize_keras_object(layer)\n", - "/usr/local/lib/python3.8/dist-packages/keras/saving/legacy/saved_model/layer_serialization.py:134: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " return serialization.serialize_keras_object(obj)\n", - "/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n", - "/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. 
Compile it manually.\n", - "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'city_id_count', which is not being used by any downstream operator in the ensemble graph.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer TFSharedEmbeddings(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " 
(booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:Skipping full serialization of Keras layer Dropout(\n", - " (_feature_shapes): Dict(\n", - " (city_id_list): TensorShape([64, None, 1])\n", - " (booker_country_list): TensorShape([64, None, 1])\n", - " (hotel_country_list): TensorShape([64, None, 1])\n", - " (weekday_checkin_list): TensorShape([64, None, 1])\n", - " (weekday_checkout_list): TensorShape([64, None, 1])\n", - " )\n", - " (_feature_dtypes): Dict(\n", - " (city_id_list): tf.int64\n", - " (booker_country_list): tf.int64\n", - " (hotel_country_list): tf.int64\n", - " (weekday_checkin_list): tf.int64\n", - " (weekday_checkout_list): tf.int64\n", - " )\n", - "), because it is not built.\n", - "WARNING:absl:Found untraced functions such as model_context_layer_call_fn, model_context_layer_call_and_return_conditional_losses, sequence_mask_random_layer_call_fn, sequence_mask_random_layer_call_and_return_conditional_losses, sequence_mask_last_layer_call_fn while saving (showing 5 of 108). These functions will not be directly callable after loading.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:tensorflow:Assets written to: /workspace/data/ensemble/1_predicttensorflowtriton/1/model.savedmodel/assets\n", - "/usr/local/lib/python3.8/dist-packages/merlin/models/tf/utils/tf_utils.py:101: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " config[key] = tf.keras.utils.serialize_keras_object(maybe_value)\n", - "/usr/local/lib/python3.8/dist-packages/merlin/models/tf/core/combinators.py:288: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " config[i] = tf.keras.utils.serialize_keras_object(layer)\n", - "/usr/local/lib/python3.8/dist-packages/keras/saving/legacy/saved_model/layer_serialization.py:134: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", - " return serialization.serialize_keras_object(obj)\n", - "/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py:120: UserWarning: The initializer TruncatedNormal is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). 
Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" - ] - } - ], - "source": [ - "from merlin.systems.dag.ops.tensorflow import PredictTensorflow\n", - "from merlin.systems.dag.ensemble import Ensemble\n", - "from merlin.systems.dag.ops.workflow import TransformWorkflow\n", - "\n", - "inf_ops = wf.input_schema.column_names >> TransformWorkflow(wf) >> PredictTensorflow(model)\n", - "\n", - "ensemble = Ensemble(inf_ops, wf.input_schema)\n", - "ensemble.export(os.path.join(OUTPUT_DATA_DIR, 'ensemble'));" - ] - }, - { - "cell_type": "markdown", - "id": "5edc6046", - "metadata": {}, - "source": [ - "After we export the ensemble, we are ready to start the Triton Inference Server.\n", - "\n", - "The server is installed in Merlin Tensorflow and Merlin PyTorch containers. If you are not using one of our containers, then ensure it is installed in your environment. For more information, see the Triton Inference Server [documentation](https://github.com/triton-inference-server/server/blob/r22.03/README.md#documentation).\n", - "\n", - "You can start the server by running the following command:\n", - "\n", - "```tritonserver --model-repository={OUTPUT_DATA_DIR}/ensemble/```\n", - "\n", - "For the --model-repository argument, specify the same value as the `export_path` that you specified previously in the `ensemble.export` method.\n", - "\n", - "After you run the `tritonserver` command, wait until your terminal shows messages like the following example:\n", - "\n", - "I0414 18:29:50.741833 4067 grpc_server.cc:4421] Started GRPCInferenceService at 0.0.0.0:8001
\n", - "I0414 18:29:50.742197 4067 http_server.cc:3113] Started HTTPService at 0.0.0.0:8000
\n", - "I0414 18:29:50.783470 4067 http_server.cc:178] Started Metrics Service at 0.0.0.0:8002\n", - "\n", - "Let us now package our data for inference. We will send the first 4 rows of our validation data, which corresponds to a single trip. The data will be first processed by the `NVTabular` workflow and subsequentally passed to our transformer model for predicting. " - ] - }, - { - "cell_type": "markdown", - "id": "d83a304d", - "metadata": {}, - "source": [ - "Let us send the first 4 rows of our validation data to Triton. This will correspond to a single trip (all rows have the same `utrip_id`) with four stops." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "3cad9026", - "metadata": {}, - "outputs": [], - "source": [ - "from merlin.systems.triton import convert_df_to_triton_input\n", - "\n", - "validation_data = validation_set_dataset.compute()\n", - "inputs = convert_df_to_triton_input(wf.input_schema, validation_data.iloc[:4])" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "c508adce", - "metadata": {}, - "outputs": [], - "source": [ - "import tritonclient.grpc as grpcclient\n", - "\n", - "with grpcclient.InferenceServerClient(\"localhost:8001\") as client:\n", - " response = client.infer('executor_model', inputs)" - ] - }, - { - "cell_type": "markdown", - "id": "6d34eecf", - "metadata": {}, - "source": [ - "The response consists of logits coming from our model." - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "b3284691", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-2.8206294 , -1.3849059 , 1.9042726 , ..., 0.851537 ,\n", - " -2.4237087 , -0.73849726]], dtype=float32)" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response.as_numpy('city_id_list/categorical_output')" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "824d2b4f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(1, 37203)" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "predictions = response.as_numpy('city_id_list/categorical_output')\n", - "predictions.shape" - ] - }, - { - "cell_type": "markdown", - "id": "fc5d415b", - "metadata": {}, - "source": [ - "The above values are logits output from the last layer of our model. They correspond in size to the cardinality of `city_id`, our target variable:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "29a8c0bd", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "37203" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cardinality = wf.output_schema['city_id_list'].properties['embedding_sizes']['cardinality']\n", - "cardinality" - ] - }, - { - "cell_type": "markdown", - "id": "3c54c30f", - "metadata": {}, - "source": [ - "## Summary" - ] - }, - { - "cell_type": "markdown", - "id": "709c07fb", - "metadata": {}, - "source": [ - "We have trained a transformer model for the next item prediction task using language model masking.\n", - "\n", - "For another session-based example that goes deeper into data preprocessing and that covers several advanced techniques (Weight Tying, Temperature Scaling) please see [Session-Based Next Item Prediction for Fashion E-Commerce](https://github.com/NVIDIA-Merlin/models/blob/t4rec_use_case/examples/usecases/ecommerce-session-based-next-item-prediction-for-fashion.ipynb). 
" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/merlin/datasets/entertainment/music_streaming/schema.json b/merlin/datasets/entertainment/music_streaming/schema.json index cdef74879a..bfe3a14530 100644 --- a/merlin/datasets/entertainment/music_streaming/schema.json +++ b/merlin/datasets/entertainment/music_streaming/schema.json @@ -98,7 +98,8 @@ "annotation": { "tag": [ "categorical", - "user_id" + "user_id", + "user" ] } }, diff --git a/merlin/models/tf/__init__.py b/merlin/models/tf/__init__.py index 85708e2a4c..da62ef0a7e 100644 --- a/merlin/models/tf/__init__.py +++ b/merlin/models/tf/__init__.py @@ -155,6 +155,7 @@ HashedCross, HashedCrossAll, PrepareFeatures, + PrepareListFeatures, ToDense, ToOneHot, ToSparse, @@ -226,6 +227,7 @@ "MatrixFactorizationBlock", "QueryItemIdsEmbeddingsBlock", "PrepareFeatures", + "PrepareListFeatures", "ToSparse", "ToDense", "ToTarget", diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index 0dd47b5187..5833b2d674 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -27,10 +27,12 @@ from merlin.models.tf.core.prediction import TopKPrediction from merlin.models.tf.inputs.base import InputBlockV2 from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable +from merlin.models.tf.loader import Loader from merlin.models.tf.models.base import BaseModel, get_output_schema from merlin.models.tf.outputs.topk import TopKOutput from merlin.models.tf.transforms.features import PrepareFeatures from merlin.models.tf.utils import tf_utils +from merlin.models.tf.utils.batch_utils import TFModelEncode from merlin.schema import ColumnSchema, Schema, Tags @@ -83,7 +85,7 @@ def __init__( def encode( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], index: Union[str, ColumnSchema, Schema, Tags], batch_size: int, **kwargs, @@ -92,7 +94,7 @@ def encode( Parameters ---------- - dataset: merlin.io.Dataset + dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] The dataset to encode. index: Union[str, ColumnSchema, Schema, Tags] The index to use for encoding. @@ -126,7 +128,7 @@ def encode( def batch_predict( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], batch_size: int, output_schema: Optional[Schema] = None, index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, @@ -136,8 +138,8 @@ def batch_predict( Parameters ---------- - dataset: merlin.io.Dataset - Dataset to predict on. + dataset: Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] + Dataset or Loader to predict on. batch_size: int Batch size to use for prediction. 
@@ -160,24 +162,46 @@ def batch_predict( raise ValueError("Only one column can be used as index") index = index.first.name + dataset_schema = None if hasattr(dataset, "schema"): - if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)): + dataset_schema = dataset.schema + data_output_schema = dataset_schema + if isinstance(dataset, Loader): + data_output_schema = dataset.output_schema + if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)): raise ValueError( f"Model schema {self.schema.column_names} does not match dataset schema" - + f" {dataset.schema.column_names}" + + f" {data_output_schema.column_names}" ) + loader_transforms = None + if isinstance(dataset, Loader): + loader_transforms = dataset.transforms + batch_size = dataset.batch_size + dataset = dataset.dataset + # Check if merlin-dataset is passed if hasattr(dataset, "to_ddf"): dataset = dataset.to_ddf() - from merlin.models.tf.utils.batch_utils import TFModelEncode + model_encode = TFModelEncode( + self, + batch_size=batch_size, + loader_transforms=loader_transforms, + schema=dataset_schema, + **kwargs, + ) - model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) encode_kwargs = {} if output_schema: encode_kwargs["filter_input_columns"] = output_schema.column_names - predictions = dataset.map_partitions(model_encode, **encode_kwargs) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(dataset.head(), **encode_kwargs) + output_dtypes = sample_output.dtypes.to_dict() + + predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs) if index: predictions = predictions.set_index(index) @@ -577,7 +601,7 @@ def encode_candidates( def batch_predict( self, - dataset: merlin.io.Dataset, + dataset: Union[merlin.io.Dataset, Loader], batch_size: int, output_schema: Optional[Schema] = None, **kwargs, @@ -586,8 +610,8 @@ def batch_predict( Parameters ---------- - dataset : merlin.io.Dataset - Raw queries features dataset + dataset : Union[merlin.io.Dataset, merlin.models.tf.loader.Loader] + Raw queries features dataset or Loader batch_size : int The number of queries to process at each prediction step output_schema: Schema, optional @@ -600,20 +624,34 @@ def batch_predict( """ from merlin.models.tf.utils.batch_utils import TFModelEncode + loader_transforms = None + if isinstance(dataset, Loader): + loader_transforms = dataset.transforms + batch_size = dataset.batch_size + dataset = dataset.dataset + + dataset_schema = dataset.schema + dataset = dataset.to_ddf() + model_encode = TFModelEncode( model=self, batch_size=batch_size, + loader_transforms=loader_transforms, + schema=dataset_schema, output_names=TopKPrediction.output_names(self.k), **kwargs, ) - dataset = dataset.to_ddf() - encode_kwargs = {} if output_schema: encode_kwargs["filter_input_columns"] = output_schema.column_names - predictions = dataset.map_partitions(model_encode, **encode_kwargs) + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(dataset.head(), **encode_kwargs) + output_dtypes = sample_output.dtypes.to_dict() + + predictions = dataset.map_partitions(model_encode, meta=output_dtypes, **encode_kwargs) return merlin.io.Dataset(predictions) diff --git a/merlin/models/tf/core/index.py b/merlin/models/tf/core/index.py index fc36f1a114..ba2632f70e 100644 --- a/merlin/models/tf/core/index.py +++ 
b/merlin/models/tf/core/index.py @@ -113,7 +113,15 @@ def get_candidates_dataset( model_encode = TFModelEncode(model=block, output_concat_func=np.concatenate) data = data.to_ddf() - embedding_ddf = data.map_partitions(model_encode, filter_input_columns=[id_column]) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(data.head(), filter_input_columns=[id_column]) + output_dtypes = sample_output.dtypes.to_dict() + + embedding_ddf = data.map_partitions( + model_encode, meta=output_dtypes, filter_input_columns=[id_column] + ) embedding_df = embedding_ddf.compute(scheduler="synchronous") embedding_df.set_index(id_column, inplace=True) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index dd8e96a440..af0cb14c7e 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -1553,7 +1553,7 @@ def predict( return out def batch_predict( - self, dataset: merlin.io.Dataset, batch_size: int, **kwargs + self, dataset: Union[merlin.io.Dataset, Loader], batch_size: int, **kwargs ) -> merlin.io.Dataset: """Batched prediction using the Dask. Parameters @@ -1565,21 +1565,45 @@ def batch_predict( Returns merlin.io.Dataset ------- """ + dataset_schema = None if hasattr(dataset, "schema"): - if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)): + dataset_schema = dataset.schema + data_output_schema = dataset_schema + if isinstance(dataset, Loader): + data_output_schema = dataset.output_schema + + if not set(self.schema.column_names).issubset(set(data_output_schema.column_names)): raise ValueError( f"Model schema {self.schema.column_names} does not match dataset schema" - + f" {dataset.schema.column_names}" + + f" {data_output_schema.column_names}" ) + loader_transforms = None + if isinstance(dataset, Loader): + loader_transforms = dataset.transforms + batch_size = dataset.batch_size + dataset = dataset.dataset + # Check if merlin-dataset is passed if hasattr(dataset, "to_ddf"): dataset = dataset.to_ddf() from merlin.models.tf.utils.batch_utils import TFModelEncode - model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) - predictions = dataset.map_partitions(model_encode) + model_encode = TFModelEncode( + self, + batch_size=batch_size, + loader_transforms=loader_transforms, + schema=dataset_schema, + **kwargs, + ) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = model_encode(dataset.head()) + output_dtypes = sample_output.dtypes.to_dict() + + predictions = dataset.map_partitions(model_encode, meta=output_dtypes) return merlin.io.Dataset(predictions) @@ -1774,7 +1798,7 @@ def build(self, input_shape=None): self.built = True - def call(self, inputs, targets=None, training=False, testing=False, output_context=False): + def call(self, inputs, targets=None, training=None, testing=None, output_context=None): """ Method for forward pass of the model. 
@@ -1796,6 +1820,9 @@ def call(self, inputs, targets=None, training=False, testing=False, output_conte Tensor or tuple of Tensor and ModelContext Output of the model, and optionally the context """ + training = training or False + testing = testing or False + output_context = output_context or False outputs = inputs features = self._prepare_features(inputs, targets=targets) if isinstance(features, tuple): @@ -2354,7 +2381,13 @@ def query_embeddings( get_user_emb = QueryEmbeddings(self, batch_size=batch_size) dataset = unique_rows_by_features(dataset, query_tag, query_id_tag).to_ddf() - embeddings = dataset.map_partitions(get_user_emb) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = get_user_emb(dataset.head()) + output_dtypes = sample_output.dtypes.to_dict() + + embeddings = dataset.map_partitions(get_user_emb, meta=output_dtypes) return merlin.io.Dataset(embeddings) @@ -2389,7 +2422,13 @@ def item_embeddings( get_item_emb = ItemEmbeddings(self, batch_size=batch_size) dataset = unique_rows_by_features(dataset, item_tag, item_id_tag).to_ddf() - embeddings = dataset.map_partitions(get_item_emb) + + # Processing a sample of the dataset with the model encoder + # to get the output dataframe dtypes + sample_output = get_item_emb(dataset.head()) + output_dtypes = sample_output.dtypes.to_dict() + + embeddings = dataset.map_partitions(get_item_emb, meta=output_dtypes) return merlin.io.Dataset(embeddings) @@ -2492,20 +2531,20 @@ def query_embeddings( def candidate_embeddings( self, - dataset: Optional[merlin.io.Dataset] = None, + data: Optional[Union[merlin.io.Dataset, Loader]] = None, index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, **kwargs, ) -> merlin.io.Dataset: if self.has_candidate_encoder: candidate = self.candidate_encoder - if dataset is not None and hasattr(candidate, "encode"): - return candidate.encode(dataset, index=index, **kwargs) + if data is not None and hasattr(candidate, "encode"): + return candidate.encode(data, index=index, **kwargs) if hasattr(candidate, "to_dataset"): return candidate.to_dataset(**kwargs) - return candidate.encode(dataset, index=index, **kwargs) + return candidate.encode(data, index=index, **kwargs) if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)): return self.last.to_dataset() diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py index 312019cd92..b1090f6605 100644 --- a/merlin/models/tf/transforms/sequence.py +++ b/merlin/models/tf/transforms/sequence.py @@ -184,7 +184,10 @@ def compute_output_shape(self, input_shape): def get_config(self): """Returns the config of the layer as a Python dictionary.""" config = super().get_config() - config["target"] = self.target + target = self.target + if isinstance(target, ColumnSchema): + target = schema_utils.schema_to_tensorflow_metadata_json(Schema([target])) + config["target"] = target return config @@ -193,6 +196,10 @@ def from_config(cls, config): """Creates layer from its config. 
Returning the instance.""" config = tf_utils.maybe_deserialize_keras_objects(config, ["pre", "post", "aggregation"]) config["schema"] = schema_utils.tensorflow_metadata_json_to_schema(config["schema"]) + if config["target"].startswith("{"): # we have a schema + config["target"] = [ + col for col in schema_utils.tensorflow_metadata_json_to_schema(config["target"]) + ][0] schema = config.pop("schema") target = config.pop("target") return cls(schema, target, **config) diff --git a/merlin/models/tf/utils/batch_utils.py b/merlin/models/tf/utils/batch_utils.py index 9ffce6b649..2e83957eb3 100644 --- a/merlin/models/tf/utils/batch_utils.py +++ b/merlin/models/tf/utils/batch_utils.py @@ -8,8 +8,7 @@ from merlin.models.tf.core.base import Block from merlin.models.tf.loader import Loader from merlin.models.tf.models.base import Model, RetrievalModel, get_task_names_from_outputs -from merlin.models.utils.schema_utils import select_targets -from merlin.schema import Schema, Tags +from merlin.schema import Schema class ModelEncode: @@ -75,6 +74,7 @@ def __init__( block_load_func: tp.Optional[tp.Callable[[str], Block]] = None, schema: tp.Optional[Schema] = None, output_concat_func=None, + loader_transforms=None, ): save_path = save_path or tempfile.mkdtemp() model.save(save_path) @@ -96,7 +96,9 @@ def __init__( super().__init__( save_path, output_names, - data_iterator_func=data_iterator_func(self.schema, batch_size=batch_size), + data_iterator_func=data_iterator_func( + self.schema, batch_size=batch_size, loader_transforms=loader_transforms + ), model_load_func=model_load_func, model_encode_func=model_encode, output_concat_func=output_concat_func, @@ -173,21 +175,15 @@ def encode_output(output: tf.Tensor): return output.numpy() -def data_iterator_func(schema, batch_size: int = 512): +def data_iterator_func(schema, batch_size: int = 512, loader_transforms=None): import merlin.io.dataset - cat_cols = schema.select_by_tag(Tags.CATEGORICAL).excluding_by_tag(Tags.TARGET).column_names - cont_cols = schema.select_by_tag(Tags.CONTINUOUS).excluding_by_tag(Tags.TARGET).column_names - targets = select_targets(schema).column_names - def data_iterator(dataset): return Loader( - merlin.io.dataset.Dataset(dataset), + merlin.io.dataset.Dataset(dataset, schema=schema), batch_size=batch_size, - cat_names=cat_cols, - cont_names=cont_cols, - label_names=targets, shuffle=False, + transforms=loader_transforms, ) return data_iterator diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index 025c8ba0dc..4042784156 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -16,27 +16,54 @@ from merlin.models.torch import schema from merlin.models.torch.batch import Batch, Sequence -from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock +from merlin.models.torch.block import ( + BatchBlock, + Block, + ParallelBlock, + ResidualBlock, + ShortcutBlock, + repeat, + repeat_parallel, + repeat_parallel_like, +) +from merlin.models.torch.blocks.attention import CrossAttentionBlock from merlin.models.torch.blocks.dlrm import DLRMBlock +from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock from merlin.models.torch.blocks.mlp import MLPBlock +from merlin.models.torch.functional import map, walk from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys -from merlin.models.torch.inputs.tabular import TabularInputBlock -from 
merlin.models.torch.models.base import Model -from merlin.models.torch.models.ranking import DLRMModel +from merlin.models.torch.inputs.tabular import TabularInputBlock, stack_context +from merlin.models.torch.models.base import Model, MultiLoader +from merlin.models.torch.models.ranking import DCNModel, DLRMModel from merlin.models.torch.outputs.base import ModelOutput -from merlin.models.torch.outputs.classification import BinaryOutput +from merlin.models.torch.outputs.classification import ( + BinaryOutput, + CategoricalOutput, + CategoricalTarget, + EmbeddingTablePrediction, +) from merlin.models.torch.outputs.regression import RegressionOutput from merlin.models.torch.outputs.tabular import TabularOutputBlock +from merlin.models.torch.predict import DaskEncoder, DaskPredictor, EncoderBlock from merlin.models.torch.router import RouterBlock from merlin.models.torch.transforms.agg import Concat, Stack +from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding + +input_schema = schema.input_schema +output_schema = schema.output_schema +target_schema = schema.target_schema +feature_schema = schema.feature_schema __all__ = [ "Batch", "BinaryOutput", "Block", + "BatchBlock", + "DLRMBlock", "MLPBlock", "Model", + "MultiLoader", "EmbeddingTable", "EmbeddingTables", "ParallelBlock", @@ -55,6 +82,29 @@ "Concat", "Stack", "schema", + "repeat", + "repeat_parallel", + "repeat_parallel_like", + "CategoricalOutput", + "CategoricalTarget", + "EmbeddingTablePrediction", + "input_schema", + "output_schema", + "feature_schema", + "target_schema", "DLRMBlock", "DLRMModel", + "DCNModel", + "MMOEBlock", + "PLEBlock", + "CGCBlock", + "TabularPadding", + "BroadcastToSequence", + "EncoderBlock", + "DaskEncoder", + "DaskPredictor", + "stack_context", + "CrossAttentionBlock", + "map", + "walk", ] diff --git a/merlin/models/torch/batch.py b/merlin/models/torch/batch.py index 72d813c00c..610cef03ce 100644 --- a/merlin/models/torch/batch.py +++ b/merlin/models/torch/batch.py @@ -20,8 +20,6 @@ from merlin.dataloader.torch import Loader from merlin.io import Dataset -from merlin.models.torch import schema -from merlin.schema import Schema @torch.jit.script @@ -67,6 +65,9 @@ def __init__( def __contains__(self, name: str) -> bool: return name in self.lengths + def __bool__(self) -> bool: + return bool(self.lengths) + def length(self, name: str = "default") -> torch.Tensor: """Retrieves a length tensor from a sequence by name. @@ -119,6 +120,16 @@ def device(self) -> torch.device: raise ValueError("Sequence is empty") + def flatten_to_dict(self) -> Dict[str, torch.Tensor]: + outputs: Dict[str, torch.Tensor] = {} + for key, value in self.lengths.items(): + outputs["lengths." + key] = value + + for key, value in self.masks.items(): + outputs["masks." 
+ key] = value + + return outputs + @torch.jit.script class Batch: @@ -166,7 +177,12 @@ def __init__( else: raise ValueError("Targets must be a tensor or a dictionary of tensors") self.targets: Dict[str, torch.Tensor] = _targets - self.sequences: Optional[Sequence] = sequences + if torch.jit.isinstance(sequences, Sequence): + _sequences = Sequence(sequences.lengths, sequences.masks) + else: + _masks: Dict[str, torch.Tensor] = {} + _sequences = Sequence(_masks) + self.sequences: Sequence = _sequences @staticmethod @torch.jit.ignore @@ -279,6 +295,118 @@ def target(self, name: str = "default") -> torch.Tensor: raise ValueError("Batch has multiple target, please specify a target name") + def inputs(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if len(self.features) == 1 and "default" in self.features: + return self.features["default"] + + return self.features + + def flatten_as_dict(self, inputs: Optional["Batch"]) -> Dict[str, torch.Tensor]: + """ + Flatten features, targets, and sequences into a dictionary of tensors. + + Each key should be prefixed with "features.", "targets.", "masks." or "lengths." + + If inputs is provided, it includes all keys that are present in both self and inputs, + with the value from self being used when a key is present in both. + + Parameters + ---------- + inputs : Batch, optional + Another Batch object to include in the flattening process. The keys from the input + batch are also added with a prefix of "inputs.", by default None + + Returns + ------- + Dict[str, torch.Tensor] + A dictionary containing all the flattened features, targets, and sequences. + """ + flat_dict: Dict[str, torch.Tensor] = self._flatten() + dummy_tensor = torch.tensor(0) + + if torch.jit.isinstance(inputs, Batch) and inputs is not self: + _input_dict: Dict[str, torch.Tensor] = inputs._flatten() + for key, val in _input_dict.items(): + flat_dict["inputs." + key] = dummy_tensor + if ( + not key.endswith("__values") + and not key.endswith("__offsets") + and key not in flat_dict + ): + flat_dict[key] = val + + return flat_dict + + def _flatten(self) -> Dict[str, torch.Tensor]: + """ + Helper function to flatten features, targets, and sequences of the current batch. + + Returns + ------- + Dict[str, torch.Tensor] + A dictionary containing all the flattened features, targets, and sequences. + """ + flat_dict = {} + + for key, value in self.features.items(): + flat_dict["features." + key] = value + + for key, value in self.targets.items(): + flat_dict["targets." + key] = value + + _sequence_dict = self.sequences.flatten_to_dict() + if _sequence_dict: + flat_dict.update(_sequence_dict) + + return flat_dict + + @staticmethod + def from_partial_dict(input: Dict[str, torch.Tensor], batch: "Batch") -> "Batch": + """ + The input param comes from flatten_as_dict. + + It could be that certain keys are missing from the input dict, in which case + we should use the values from the batch object. 
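+
+        Example (a minimal sketch with toy key names; in practice the dict
+        comes from ``flatten_as_dict``, whose ``inputs.*`` marker keys must be
+        kept so that present keys are not overwritten by the fallback)::
+
+            base = Batch({"a": torch.tensor([1.0]), "b": torch.tensor([3.0])})
+            partial = Batch({"a": torch.tensor([2.0])}).flatten_as_dict(base)
+            rebuilt = Batch.from_partial_dict(partial, base)
+            # rebuilt.features["a"] is tensor([2.]); "b" is carried over from `base`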
+ """ + features = {} + targets = {} + lengths = {} + masks = {} + + for key, value in input.items(): + key_split = key.split(".") + if key_split[0] == "features": + features[key_split[1]] = value + elif key_split[0] == "targets": + targets[key_split[1]] = value + elif key_split[0] == "lengths": + lengths[key_split[1]] = value + elif key_split[0] == "masks": + masks[key_split[1]] = value + + # If a key is missing in the input dict and was in the inputs of flatten_as_dict, + # use the value from the batch object + for key in batch.features: + if f"inputs.features.{key}" not in input: + features[key] = batch.features[key] + for key in batch.targets: + if f"inputs.targets.{key}" not in input: + targets[key] = batch.targets[key] + if batch.sequences is not None: + for key in batch.sequences.lengths: + if f"inputs.lengths.{key}" not in input: + lengths[key] = batch.sequences.lengths[key] + for key in batch.sequences.masks: + if f"inputs.masks.{key}" not in input: + masks[key] = batch.sequences.masks[key] + + if lengths or masks: + sequences = Sequence(lengths, masks) + else: + sequences = None + + return Batch(features, targets, sequences) + def __bool__(self) -> bool: return bool(self.features) @@ -373,12 +501,3 @@ def sample_features( """ return sample_batch(data, batch_size, shuffle).features - - -@schema.output.register_tensor(Batch) -def _(input): - output_schema = Schema() - output_schema += schema.output.tensors(input.features) - output_schema += schema.output.tensors(input.targets) - - return output_schema diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index 42dede5b9b..feeaebe0c7 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -17,19 +17,27 @@ import inspect import textwrap from copy import deepcopy -from typing import Dict, Optional, Tuple, TypeVar, Union +from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union, runtime_checkable import torch from torch import nn from merlin.models.torch import schema -from merlin.models.torch.batch import Batch +from merlin.models.torch.batch import Batch, Sequence from merlin.models.torch.container import BlockContainer, BlockContainerDict from merlin.models.torch.registry import registry from merlin.models.torch.utils.traversal_utils import TraversableMixin from merlin.models.utils.registry import RegistryMixin from merlin.schema import Schema +TensorOrDict = Union[torch.Tensor, Dict[str, torch.Tensor]] + + +@runtime_checkable +class HasKeys(Protocol): + def keys(self): + ... + class Block(BlockContainer, RegistryMixin, TraversableMixin): """A base-class that calls it's modules sequentially. @@ -47,9 +55,7 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin): def __init__(self, *module: nn.Module, name: Optional[str] = None): super().__init__(*module, name=name) - def forward( - self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None - ): + def forward(self, inputs: TensorOrDict, batch: Optional[Batch] = None): """ Forward pass through the block. Applies each contained module sequentially on the input. @@ -87,15 +93,13 @@ def repeat(self, n: int = 1, name=None) -> "Block": Block The new block created by repeating the current block `n` times. 
""" - if not isinstance(n, int): - raise TypeError("n must be an integer") + return repeat(self, n, name=name) - if n < 1: - raise ValueError("n must be greater than 0") + def repeat_parallel(self, n: int = 1, name=None) -> "ParallelBlock": + return repeat_parallel(self, n, name=name) - repeats = [self.copy() for _ in range(n - 1)] - - return Block(self, *repeats, name=name) + def repeat_parallel_like(self, like: HasKeys, agg=None) -> "ParallelBlock": + return repeat_parallel_like(self, like, agg=agg) def copy(self) -> "Block": """ @@ -163,9 +167,7 @@ def __init__(self, *inputs: Union[nn.Module, Dict[str, nn.Module]]): self.branches = branches self.post = post - def forward( - self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None - ): + def forward(self, inputs: TensorOrDict, batch: Optional[Batch] = None): """Forward pass through the block. The steps are as follows: @@ -206,6 +208,9 @@ def forward( raise RuntimeError(f"Duplicate output name: {key}") outputs.update(branch_out) + elif torch.jit.isinstance(branch_out, Batch): + _flattened_batch: Dict[str, torch.Tensor] = branch_out.flatten_as_dict(batch) + outputs.update(_flattened_batch) else: raise TypeError( f"Branch output must be a tensor or a dictionary of tensors. Got {_inputs}" @@ -342,6 +347,9 @@ def replace(self, pre=None, branches=None, post=None) -> "ParallelBlock": return output + def keys(self): + return self.branches.keys() + def leaf(self) -> nn.Module: if self.pre: raise ValueError("Cannot call leaf() on a ParallelBlock with a pre-processing stage") @@ -567,6 +575,168 @@ def forward( return to_return +class BatchBlock(Block): + """ + Class to use for `Batch` creation. We can use this class to create a `Batch` from + - a tensor or a dictionary of tensors + - a `Batch` object + - a tuple of features and targets + + Example usage:: + >>> batch = mm.BatchBlock()(torch.ones(1, 1)) + >>> batch + Batch(features={"default": tensor([[1.]])}) + + """ + + def forward( + self, + inputs: Union[Batch, TensorOrDict], + targets: Optional[TensorOrDict] = None, + sequences: Optional[Sequence] = None, + batch: Optional[Batch] = None, + ) -> Batch: + """ + Perform forward propagation on either a Batch object, or on inputs, targets and sequences + which are then packed into a Batch. + + Parameters + ---------- + inputs : Union[Batch, TensorOrDict] + Either a Batch object or a dictionary of tensors. + + targets : Optional[TensorOrDict], optional + A dictionary of tensors, by default None + + sequences : Optional[Sequence], optional + A sequence of tensors, by default None + + batch : Optional[Batch], optional + A Batch object, by default None + + Returns + ------- + Batch + The resulting Batch after forward propagation. + """ + if torch.jit.isinstance(batch, Batch): + return self.forward_batch(batch) + if torch.jit.isinstance(inputs, Batch): + return self.forward_batch(inputs) + + return self.forward_batch(Batch(inputs, targets, sequences)) + + def forward_batch(self, batch: Batch) -> Batch: + """ + Perform forward propagation on a Batch object. + + For each module in the block, this method performs a forward pass with the + current output features and the original batch object. + - If a module returns a Batch object, this becomes the new output. + - If a module returns a dictionary of tensors, a new Batch object is created + from this dictionary and the original batch object. The new Batch replaces + the current output. This is useful when a module only modifies a subset of + the batch. 
+
+
+        Parameters
+        ----------
+        batch : Batch
+            A Batch object.
+
+        Returns
+        -------
+        Batch
+            The resulting Batch after forward propagation.
+
+        Raises
+        ------
+        RuntimeError
+            When the output of a module is neither a Batch object nor a dictionary of tensors.
+        """
+        output = batch
+        for module in self.values:
+            module_out = module(output.features, batch=output)
+            if torch.jit.isinstance(module_out, Batch):
+                output = module_out
+            elif torch.jit.isinstance(module_out, Dict[str, torch.Tensor]):
+                output = Batch.from_partial_dict(module_out, batch)
+            else:
+                raise RuntimeError("Module must return a Batch or a dict of tensors")
+
+        return output
+
+
+def _validate_n(n: int) -> None:
+    if not isinstance(n, int):
+        raise TypeError("n must be an integer")
+
+    if n < 1:
+        raise ValueError("n must be greater than 0")
+
+
+def repeat(module: nn.Module, n: int = 1, name=None) -> Block:
+    """
+    Creates a new block that applies the given module `n` times in sequence.
+    The result contains the original module followed by `n - 1` deep copies of it.
+
+    Parameters
+    ----------
+    module: nn.Module
+        The module to be repeated.
+    n : int
+        The number of times to repeat the module.
+    name : Optional[str], default = None
+        The name for the new block. If None, no name is assigned.
+
+    Returns
+    -------
+    Block
+        The new block created by repeating the module `n` times.
+    """
+    _validate_n(n)
+
+    repeats = [
+        module.copy() if hasattr(module, "copy") else deepcopy(module) for _ in range(n - 1)
+    ]
+
+    return Block(module, *repeats, name=name)
+
+
+def repeat_parallel(module: nn.Module, n: int = 1, agg=None) -> ParallelBlock:
+    _validate_n(n)
+
+    # Keep the original module as branch "0" and add `n - 1` copies as the
+    # remaining branches (mirroring `repeat` and `repeat_parallel_like`).
+    branches = {"0": module}
+    branches.update(
+        {
+            str(i): module.copy() if hasattr(module, "copy") else deepcopy(module)
+            for i in range(1, n)
+        }
+    )
+
+    output = ParallelBlock(branches)
+    if agg:
+        output.append(Block.parse(agg))
+
+    return output
+
+
+def repeat_parallel_like(module: nn.Module, like: HasKeys, agg=None) -> ParallelBlock:
+    branches = {}
+
+    if isinstance(like, Schema):
+        keys = like.column_names
+    else:
+        keys = list(like.keys())
+
+    for i, key in enumerate(keys):
+        if i == 0:
+            branches[str(key)] = module
+        else:
+            branches[str(key)] = module.copy() if hasattr(module, "copy") else deepcopy(module)
+
+    output = ParallelBlock(branches)
+    if agg:
+        output.append(Block.parse(agg))
+
+    return output
+
+
 def get_pre(module: nn.Module) -> BlockContainer:
     if hasattr(module, "pre"):
         return module.pre
@@ -588,31 +758,31 @@ def set_pre(module: nn.Module, pre: BlockContainer):
         return set_pre(module[0], pre)
 
 
-@schema.input.register(BlockContainer)
+@schema.input_schema.register(BlockContainer)
 def _(module: BlockContainer, input: Schema):
-    return schema.input(module[0], input) if module else input
+    return schema.input_schema(module[0], input) if module else input
 
 
-@schema.input.register(ParallelBlock)
+@schema.input_schema.register(ParallelBlock)
 def _(module: ParallelBlock, input: Schema):
     if module.pre:
-        return schema.input(module.pre)
+        return schema.input_schema(module.pre)
 
     out_schema = Schema()
     for branch in module.branches.values():
-        out_schema += schema.input(branch, input)
+        out_schema += schema.input_schema(branch, input)
 
     return out_schema
 
 
-@schema.output.register(ParallelBlock)
+@schema.output_schema.register(ParallelBlock)
 def _(module: ParallelBlock, input: Schema):
     if module.post:
-        return schema.output(module.post, input)
+        return schema.output_schema(module.post, input)
 
     output = Schema()
     for name, branch in module.branches.items():
-        branch_schema = schema.output(branch, input)
+
branch_schema = schema.output_schema(branch, input)
 
         if len(branch_schema) == 1 and branch_schema.first.name == "output":
             branch_schema = Schema([branch_schema.first.with_name(name)])
@@ -622,9 +792,9 @@ def _(module: ParallelBlock, input: Schema):
     return output
 
 
-@schema.output.register(BlockContainer)
+@schema.output_schema.register(BlockContainer)
 def _(module: BlockContainer, input: Schema):
-    return schema.output(module[-1], input) if module else input
+    return schema.output_schema(module[-1], input) if module else input
 
 
 BlockT = TypeVar("BlockT", bound=BlockContainer)
@@ -720,13 +890,13 @@ def _extract_block(main, selection, route, name=None):
     if isinstance(main, ParallelBlock):
         return _extract_parallel(main, selection, route=route, name=name)
 
-    main_schema = schema.input(main)
-    route_schema = schema.input(route)
+    main_schema = schema.input_schema(main)
+    route_schema = schema.input_schema(route)
 
     if main_schema == route_schema:
         from merlin.models.torch.inputs.select import SelectFeatures
 
-        out_schema = schema.output(main, main_schema)
+        out_schema = schema.output_schema(main, main_schema)
         if len(out_schema) == 1 and out_schema.first.name == "output":
             out_schema = Schema([out_schema.first.with_name(name)])
diff --git a/merlin/models/torch/blocks/attention.py b/merlin/models/torch/blocks/attention.py
new file mode 100644
index 0000000000..6a64fa1297
--- /dev/null
+++ b/merlin/models/torch/blocks/attention.py
@@ -0,0 +1,168 @@
+from copy import deepcopy
+from typing import Dict, Optional, Union
+
+import torch
+from torch import nn
+
+from merlin.models.torch.batch import Batch
+from merlin.models.torch.block import Block
+
+
+class CrossAttentionBlock(Block):
+    """
+    Cross Attention Block module which performs a multihead attention operation
+    on a provided context and sequence.
+
+    Note that this block assumes that the input and output tensors are provided as
+    (batch, seq, feature). When using modules provided in PyTorch, e.g.,
+    ``torch.nn.MultiheadAttention``, the ``batch_first`` parameter should be
+    set to True to match the shape.
+
+    Example usage
+    -------------
+
+    >>> cross = CrossAttentionBlock(
+    ...     attention=nn.MultiheadAttention(10, 2, batch_first=True),
+    ...     key="context",
+    ...     seq_key="sequence",
+    ... )
+    >>> input_dict = {
+    ...     "context": torch.randn(1, 2, 10),
+    ...     "sequence": torch.randn(1, 6, 10),
+    ... }
+    >>> cross(input_dict)
+
+    Parameters
+    ----------
+    module : nn.Module
+        Variable length input module list.
+    attention : nn.MultiheadAttention, optional
+        Predefined multihead attention module. If not provided, it is inferred
+        from the first module.
+    name : str, optional
+        Name for the block.
+    key : str, optional
+        Key for the context tensor in the input dictionary.
+    seq_key : str, optional
+        Key for the sequence tensor in the input dictionary.
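+
+    The attended output keeps the sequence shape, e.g. for the usage above
+    (an illustrative check)::
+
+        >>> cross(input_dict).shape
+        torch.Size([1, 6, 10])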
+ """ + + def __init__( + self, + *module: nn.Module, + attention: Optional[nn.MultiheadAttention] = None, + name: str = None, + key: str = "context", + seq_key: Optional[str] = None, + ): + super().__init__(*module, name=name) + + self.key = key + self.seq_key = seq_key + if attention is None: + if not ( + hasattr(module[0], "d_model") + and hasattr(module[0], "nhead") + and hasattr(module[0], "dropout") + ): + raise ValueError("Attention module not provided and cannot be inferred from module") + + # Try to infer from module + cross_attention = nn.MultiheadAttention( + module[0].d_model, module[0].nhead, module[0].dropout + ) + else: + cross_attention = attention + + self.cross_attention = nn.ModuleList([cross_attention]) + if len(module) > 1: + for m in module: + self.cross_attention.append( + m.copy() if hasattr(m, "copy") else deepcopy(cross_attention) + ) + + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> torch.Tensor: + """ + Perform forward pass of the CrossAttentionBlock. + + Parameters + ---------- + inputs : Union[torch.Tensor, Dict[str, torch.Tensor]] + Dictionary containing the input tensors. + batch : Optional[Batch] + Optional batch information for the forward pass. + + Returns + ------- + torch.Tensor + Output tensor after the multihead attention operation. + + Raises + ------ + ValueError + If the input is a torch.Tensor instead of a dictionary. + """ + + if isinstance(inputs, torch.Tensor): + raise ValueError("CrossAttentionBlock requires a dictionary input") + + context, sequence = self.get_context(inputs), self.get_seq(inputs) + + for module, attention in zip(self.values, self.cross_attention): + sequence, _ = attention(sequence, context, context) + sequence = module(sequence, batch=batch) + + return sequence + + def get_context(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Retrieve the context tensor from the input dictionary using the key. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Input dictionary containing the tensors. + + Returns + ------- + torch.Tensor + The context tensor. + """ + return x[self.key] + + def get_seq(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Retrieve the sequence tensor from the input dictionary using the key. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Input dictionary containing the tensors. + + Returns + ------- + torch.Tensor + The sequence tensor. + + Raises + ------ + RuntimeError + If the seq_key is not found in the input dictionary or if the dictionary has more + than 2 keys and seq_key is not defined. 
+ """ + if self.seq_key is None: + if len(x) == 2: + for key in x.keys(): + if key != self.key: + return x[key] + else: + raise RuntimeError( + "Please set seq_key for when more than 2 keys are present ", + f"in the input dictionary, got: {x}.", + ) + + if self.seq_key not in x: + raise RuntimeError(f"Could not find {self.seq_key} in input dictionary, got: {x}.") + + return x[self.seq_key] diff --git a/merlin/models/torch/blocks/cross.py b/merlin/models/torch/blocks/cross.py index ebfa215bcf..c395fc8440 100644 --- a/merlin/models/torch/blocks/cross.py +++ b/merlin/models/torch/blocks/cross.py @@ -5,6 +5,7 @@ from torch import nn from torch.nn.modules.lazy import LazyModuleMixin +from merlin.models.torch.batch import Batch from merlin.models.torch.block import Block from merlin.models.torch.transforms.agg import Concat from merlin.models.utils.doc_utils import docstring_parameter @@ -127,7 +128,9 @@ def with_low_rank(cls, depth: int, low_rank: nn.Module) -> "CrossBlock": return cls(*(Block(deepcopy(low_rank), *block) for block in cls.with_depth(depth))) - def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor: + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> torch.Tensor: """Forward-pass of the cross-block. Parameters diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py index 3b638ada08..ac31f13b7f 100644 --- a/merlin/models/torch/blocks/dlrm.py +++ b/merlin/models/torch/blocks/dlrm.py @@ -42,13 +42,10 @@ class DLRMInputBlock(TabularInputBlock): """ - def __init__(self, schema: Schema, dim: int, bottom_block: Block): + def __init__(self, schema: Optional[Schema], dim: int, bottom_block: Block): super().__init__(schema) self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean")) - self.add_route(Tags.CONTINUOUS, bottom_block) - - if "categorical" not in self: - raise ValueError("DLRMInputBlock must have a categorical input") + self.add_route(Tags.CONTINUOUS, bottom_block, required=False) @docstring_parameter(dlrm_reference=_DLRM_REF) @@ -117,7 +114,7 @@ class DLRMBlock(Block): Parameters ---------- schema : Schema, optional - The schema to use for selection. Default is None. + The schema to use for selection. dim : int The dimensionality of the output vectors. bottom_block : Block @@ -139,7 +136,7 @@ class DLRMBlock(Block): def __init__( self, - schema: Schema, + schema: Optional[Schema], dim: int, bottom_block: Block, top_block: Optional[Block] = None, diff --git a/merlin/models/torch/blocks/experts.py b/merlin/models/torch/blocks/experts.py new file mode 100644 index 0000000000..627b96a1fd --- /dev/null +++ b/merlin/models/torch/blocks/experts.py @@ -0,0 +1,289 @@ +import textwrap +from functools import partial +from typing import Dict, Optional, Union + +import torch +from torch import nn + +from merlin.models.torch.batch import Batch +from merlin.models.torch.block import ( + Block, + ParallelBlock, + ShortcutBlock, + repeat_parallel, + repeat_parallel_like, +) +from merlin.models.torch.transforms.agg import Concat, Stack +from merlin.models.utils.doc_utils import docstring_parameter + +_PLE_REFERENCE = """ + References + ---------- + .. [1] Tang, Hongyan, et al. "Progressive layered extraction (ple): A novel multi-task + learning (mtl) model for personalized recommendations." + Fourteenth ACM Conference on Recommender Systems. 2020. +""" + + +class MMOEBlock(Block): + """ + Multi-gate Mixture-of-Experts (MMoE) Block introduced in [1]. 
+
+    The MMoE model builds upon the concept of an expert model by using a mixture
+    of experts for decision making. Each expert contributes independently to the
+    final decision, allowing for increased model complexity and performance
+    in multi-task learning scenarios.
+
+    Example usage for multi-task learning::
+
+        >>> outputs = mm.TabularOutputBlock(schema, init="defaults")
+        >>> mmoe = mm.MMOEBlock(
+        ...     expert=mm.MLPBlock([5]),
+        ...     num_experts=2,
+        ...     outputs=outputs,
+        ... )
+        >>> outputs.prepend_for_each(mm.MLPBlock([64]))  # Add task-towers
+        >>> outputs.prepend(mmoe)
+
+    References
+    ----------
+    [1] Ma, Jiaqi, et al. "Modeling task relationships in multi-task learning with
+    multi-gate mixture-of-experts." Proceedings of the 24th ACM SIGKDD international
+    conference on knowledge discovery & data mining. 2018.
+
+    Parameters
+    ----------
+    expert : nn.Module
+        The base expert model that serves as the foundation for the MMoE structure.
+    num_experts : int
+        The total number of experts in the MMoE model. Each expert operates independently
+        in the decision-making process.
+    outputs : Optional[ParallelBlock]
+        The output block of the model.
+        If it is an instance of ParallelBlock, the block is repeated for each expert.
+        Otherwise, a single ExpertGateBlock is used.
+    """
+
+    def __init__(
+        self, expert: nn.Module, num_experts: int, outputs: Optional[ParallelBlock] = None
+    ):
+        experts = repeat_parallel(expert, num_experts, agg=Stack(dim=1))
+        super().__init__(ShortcutBlock(experts, output_name="experts"))
+        if isinstance(outputs, ParallelBlock):
+            self.append(repeat_parallel_like(ExpertGateBlock(num_experts), outputs))
+        else:
+            self.append(ExpertGateBlock(num_experts))
+
+
+@docstring_parameter(ple_reference=_PLE_REFERENCE)
+class PLEBlock(Block):
+    """
+    Progressive Layered Extraction (PLE) Block proposed in [1].
+
+    The PLE model enhances the architecture of a typical expert model by organizing
+    shared and task-specific experts in a layered format. This layered structure
+    allows the extraction of increasingly complex features at each level and can
+    improve performance in multi-task settings.
+
+    Example usage for multi-task learning::
+
+        >>> outputs = mm.TabularOutputBlock(schema, init="defaults")
+        >>> ple = mm.PLEBlock(
+        ...     expert=mm.MLPBlock([5]),
+        ...     num_shared_experts=2,
+        ...     num_task_experts=2,
+        ...     depth=2,
+        ...     outputs=outputs,
+        ... )
+        >>> outputs.prepend_for_each(mm.MLPBlock([64]))  # Add task-towers
+        >>> outputs.prepend(ple)
+
+    {ple_reference}
+
+    Parameters
+    ----------
+    expert : nn.Module
+        The base expert model that forms the basis of the PLE structure.
+    num_shared_experts : int
+        The total count of shared experts. These experts contribute to the
+        decision process in all tasks.
+    num_task_experts : int
+        The total count of task-specific experts. These experts contribute
+        only to their specific tasks.
+    depth : int
+        The depth of the layered structure. Each layer comprises a set of experts
+        and the depth determines the number of such layers.
+    outputs : ParallelBlock
+        The output block, which encapsulates the final output from the model.
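+
+    Notes
+    -----
+    The block stacks ``depth - 1`` ``CGCBlock`` layers with a shared gate
+    (``shared_gate=True``), followed by a final ``CGCBlock`` with
+    task-specific gates only.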
+    """
+
+    def __init__(
+        self,
+        expert: nn.Module,
+        *,
+        num_shared_experts: int,
+        num_task_experts: int,
+        depth: int,
+        outputs: ParallelBlock,
+    ):
+        cgc_kwargs = {
+            "expert": expert,
+            "num_shared_experts": num_shared_experts,
+            "num_task_experts": num_task_experts,
+            "outputs": outputs,
+        }
+        super().__init__(*CGCBlock(shared_gate=True, **cgc_kwargs).repeat(depth - 1))
+        self.append(CGCBlock(**cgc_kwargs))
+
+
+@docstring_parameter(ple_reference=_PLE_REFERENCE)
+class CGCBlock(Block):
+    """
+    Implements the Customized Gate Control (CGC) proposed in [1].
+
+    The CGC model extends the capability of a typical expert model by introducing
+    shared and task-specific experts, thereby customizing the gating control per task,
+    which may lead to improved performance in multi-task settings.
+
+    {ple_reference}
+
+    Parameters
+    ----------
+    expert : nn.Module
+        The base expert model that is used as the foundation for the gating mechanism.
+    num_shared_experts : int
+        The total count of shared experts. These experts contribute to the decision
+        process in all tasks.
+    num_task_experts : int
+        The total count of task-specific experts. These experts contribute only
+        to their specific tasks.
+    outputs : ParallelBlock
+        The output block, which encapsulates the final output from the model.
+    shared_gate : bool, optional
+        Defines whether a shared gate is used across all tasks.
+        If set to True, a shared gate is used. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        expert: nn.Module,
+        *,
+        num_shared_experts: int,
+        num_task_experts: int,
+        outputs: ParallelBlock,
+        shared_gate: bool = False,
+    ):
+        shared_experts = repeat_parallel(expert, num_shared_experts, agg=Stack(dim=1))
+        expert_shortcut = partial(ShortcutBlock, output_name="experts")
+        super().__init__(expert_shortcut(shared_experts))
+
+        gates = ParallelBlock()
+        for name in outputs.branches:
+            gates.branches[name] = PLEExpertGateBlock(
+                num_shared_experts + num_task_experts,
+                task_experts=repeat_parallel(expert, num_task_experts, agg=Stack(dim=1)),
+                name=name,
+            )
+        if shared_gate:
+            gates.branches["experts"] = expert_shortcut(
+                ExpertGateBlock(num_shared_experts), propagate_shortcut=True
+            )
+
+        self.append(gates)
+
+
+class ExpertGateBlock(Block):
+    """Expert Gate Block.
+
+    # TODO: Add initialize_from_schema to remove the need to pass in num_experts
+
+    Parameters
+    ----------
+    num_experts : int
+        The number of experts used.
+    """
+
+    def __init__(self, num_experts: int):
+        super().__init__(GateBlock(num_experts))
+
+    def forward(
+        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
+    ) -> torch.Tensor:
+        if torch.jit.isinstance(inputs, torch.Tensor):
+            raise RuntimeError("ExpertGateBlock requires a dictionary input")
+
+        experts = inputs["experts"]
+        outputs = inputs["shortcut"]
+        for module in self.values:
+            outputs = module(outputs, batch=batch)
+
+        # Expand the gate scores over the expert dimension and
+        # compute the weighted sum of expert outputs
+        gated = outputs.expand_as(experts)
+
+        return (experts * gated).sum(dim=1)
+
+
+class PLEExpertGateBlock(Block):
+    """
+    Progressive Layered Extraction (PLE) Expert Gate Block.
+
+    Parameters
+    ----------
+    num_experts : int
+        The number of experts used.
+    task_experts : nn.Module
+        The expert module.
+    name : str
+        The name of the task.
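+
+    Notes
+    -----
+    The forward pass expects a dictionary input with the keys ``"experts"``
+    and ``"shortcut"`` (and optionally the task name), as produced by the
+    surrounding ``CGCBlock``.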
+    """
+
+    def __init__(self, num_experts: int, task_experts: nn.Module, name: str):
+        super().__init__(ExpertGateBlock(num_experts), name=f"PLEExpertGateBlock[{name}]")
+        self.stack = Stack(dim=1)
+        self.concat = Concat(dim=1)
+        self.task_experts = task_experts
+        self.task_name = name
+
+    def forward(
+        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
+    ) -> torch.Tensor:
+        if torch.jit.isinstance(inputs, torch.Tensor):
+            raise RuntimeError("PLEExpertGateBlock requires a dictionary input")
+
+        task_experts = self.task_experts(inputs["shortcut"], batch=batch)
+        if torch.jit.isinstance(task_experts, torch.Tensor):
+            _task = task_experts
+        elif torch.jit.isinstance(task_experts, Dict[str, torch.Tensor]):
+            _task = self.stack(task_experts)
+        else:
+            raise RuntimeError("Task experts must return a tensor or a dictionary of tensors")
+        experts = self.concat({"experts": inputs["experts"], "task_experts": _task})
+        task = inputs[self.task_name] if self.task_name in inputs else inputs["shortcut"]
+
+        outputs = {"experts": experts, "shortcut": task}
+        for block in self.values:
+            outputs = block(outputs, batch=batch)
+
+        return outputs
+
+    def __repr__(self) -> str:
+        indent_str = "    "
+        output = textwrap.indent("\n(task_experts): " + repr(self.task_experts), indent_str)
+        output += textwrap.indent("\n(gate): " + repr(self.values[0]), indent_str)
+
+        return f"{self._get_name()}({output}\n)"
+
+
+class SoftmaxGate(nn.Module):
+    """Softmax Gate for gating mechanism."""
+
+    def forward(self, gate_logits):
+        return torch.softmax(gate_logits, dim=-1).unsqueeze(-1)
+
+
+class GateBlock(Block):
+    """Gate Block for gating mechanism."""
+
+    def __init__(self, num_experts: int):
+        super().__init__()
+        self.append(nn.LazyLinear(num_experts))
+        self.append(SoftmaxGate())
diff --git a/merlin/models/torch/blocks/mlp.py b/merlin/models/torch/blocks/mlp.py
index e7e9c1334d..8038dc89f7 100644
--- a/merlin/models/torch/blocks/mlp.py
+++ b/merlin/models/torch/blocks/mlp.py
@@ -4,7 +4,7 @@
 from torch import nn
 
 from merlin.models.torch.block import Block
-from merlin.models.torch.schema import Schema, output
+from merlin.models.torch.schema import Schema, output_schema
 from merlin.models.torch.transforms.agg import Concat, MaybeAgg
 
 
@@ -84,8 +84,8 @@ def __init__(
         super().__init__(*modules)
 
 
-@output.register(nn.LazyLinear)
-@output.register(nn.Linear)
-@output.register(MLPBlock)
-def _output_schema_block(module: nn.LazyLinear, input: Schema):
-    return output.tensors(torch.ones((1, module.out_features), dtype=float))
+@output_schema.register(nn.LazyLinear)
+@output_schema.register(nn.Linear)
+@output_schema.register(MLPBlock)
+def _output_schema_block(module: nn.LazyLinear, inputs: Schema):
+    return output_schema.tensors(torch.ones((1, module.out_features), dtype=float))
diff --git a/merlin/models/torch/container.py b/merlin/models/torch/container.py
index e289c694fa..8241ee33f4 100644
--- a/merlin/models/torch/container.py
+++ b/merlin/models/torch/container.py
@@ -16,15 +16,16 @@
 
 from copy import deepcopy
 from functools import reduce
-from typing import Dict, Iterator, Optional, Union
+from typing import Dict, Iterable, Iterator, Optional, Sequence, Union
 
 from torch import nn
 from torch._jit_internal import _copy_to_script_wrapper
 
+from merlin.models.torch.functional import ContainerMixin, _TModule
 from merlin.models.torch.utils import torchscript_utils
 
 
-class BlockContainer(nn.Module):
+class BlockContainer(nn.Module, Iterable[_TModule], ContainerMixin):
     """A container class for PyTorch `nn.Module`
that allows for manipulation and traversal of multiple sub-modules as if they were a list. The modules are automatically wrapped in a TorchScriptWrapper for TorchScript compatibility. @@ -62,6 +63,23 @@ def append(self, module: nn.Module): return self + def extend(self, sequence: Sequence[nn.Module]): + """Extends the list by appending elements from the iterable. + + Parameters + ---------- + module : nn.Module + The PyTorch module to be appended. + + Returns + ------- + self + """ + for m in sequence: + self.append(m) + + return self + def prepend(self, module: nn.Module): """Prepends a given module to the beginning of the list. @@ -140,12 +158,6 @@ def __setitem__(self, idx: int, module: nn.Module) -> None: def __delitem__(self, idx: Union[slice, int]) -> None: self.values.__delitem__(idx) - def __add__(self, other) -> "BlockContainer": - for module in other: - self.append(module) - - return self - def __bool__(self) -> bool: return bool(self.values) @@ -187,7 +199,7 @@ def __init__( self, *inputs: Union[nn.Module, Dict[str, nn.Module]], name: Optional[str] = None, - block_cls=BlockContainer + block_cls=BlockContainer, ) -> None: if not inputs: inputs = [{}] diff --git a/merlin/models/torch/functional.py b/merlin/models/torch/functional.py new file mode 100644 index 0000000000..bc7fa6f14e --- /dev/null +++ b/merlin/models/torch/functional.py @@ -0,0 +1,509 @@ +import builtins +import inspect +from copy import deepcopy +from functools import wraps +from typing import Callable, Iterable, Protocol, Tuple, TypeVar, Union, runtime_checkable + +from torch import nn + + +@runtime_checkable +class HasBool(Protocol): + def __bool__(self) -> bool: + ... + + +_TModule = TypeVar("_TModule", bound=nn.Module) +ModuleFunc = Callable[[nn.Module], nn.Module] +ModuleIFunc = Callable[[nn.Module, int], nn.Module] +ModulePredicate = Callable[[nn.Module], Union[bool, HasBool]] +ModuleMapFunc = Callable[[nn.Module], Union[nn.Module, None]] + + +class ContainerMixin: + """Mixin that can be used to give a class container-like behavior. + + This mixin provides a set of methods that come from functional programming. + + """ + + def filter(self: _TModule, func: ModulePredicate, recurse: bool = False) -> _TModule: + """ + Returns a new container with modules that satisfy the filtering function. + + Example usage:: + >>> block = Block(nn.LazyLinear(10)) + >>> block.filter(lambda module: isinstance(module, nn.Linear)) + Block(nn.Linear(10, 10)) + + Parameters + ---------- + func (Callable[[Module], bool]): A function that takes a module and returns + a boolean or a boolean-like object. + recurse (bool, optional): Whether to recursively filter modules + within sub-containers. Default is False. + + Returns + ------- + Self: A new container with the filtered modules. + """ + + _to_call = _recurse(func, "filter") if recurse else func + output = self.__class__() + + for module in self: + filtered = _to_call(module) + if filtered: + if isinstance(filtered, bool): + output.append(module) + else: + output.append(filtered) + + return output + + def flatmap(self: _TModule, func: ModuleFunc) -> _TModule: + """ + Applies a function to each module and flattens the results into a new container. + + Example usage:: + >>> block = Block(nn.LazyLinear(10)) + >>> container.flatmap(lambda module: [module, module]) + Block(nn.LazyLinear(10), nn.LazyLinear(10)) + + Parameters + ---------- + func : Callable[[Module], Iterable[Module]] + A function that takes a module and returns an iterable of modules. 
+ + Returns + ------- + Self + A new container with the flattened modules. + + Raises + ------ + TypeError + If the input function is not callable. + RuntimeError + If an exception occurs during mapping the function over the module. + """ + + if not callable(func): + raise TypeError(f"Expected callable function, received: {type(func).__name__}") + + try: + mapped = self.map(func) + except Exception as e: + raise RuntimeError("Failed to map function over the module") from e + + output = self.__class__() + + try: + for sublist in mapped: + for item in sublist: + output.append(item) + except TypeError as e: + raise TypeError("Function did not return an iterable object") from e + + return output + + def forall(self, func: ModulePredicate, recurse: bool = False) -> bool: + """ + Checks if the given predicate holds for all modules in the container. + + Example usage:: + >>> block = Block(nn.LazyLinear(10)) + >>> container.forall(lambda module: isinstance(module, nn.Module)) + True + + Parameters + ---------- + func : Callable[[Module], bool] + A predicate function that takes a module and returns True or False. + recurse : bool, optional + Whether to recursively check modules within sub-containers. Default is False. + + Returns + ------- + bool + True if the predicate holds for all modules, False otherwise. + + + """ + _to_call = _recurse(func, "forall") if recurse else func + return all(_to_call(module) for module in self) + + def map(self: _TModule, func: ModuleFunc, recurse: bool = False) -> _TModule: + """ + Applies a function to each module and returns a new container with the results. + + Example usage:: + >>> block = Block(nn.LazyLinear(10)) + >>> container.map(lambda module: nn.ReLU()) + Block(nn.ReLU()) + + Parameters + ---------- + func : Callable[[Module], Module] + A function that takes a module and returns a modified module. + recurse : bool, optional + Whether to recursively map the function to modules within sub-containers. + Default is False. + + Returns + ------- + _TModule + A new container with the mapped modules. + """ + + _to_call = _recurse(func, "map") if recurse else func + + return self.__class__(*(_to_call(module) for module in self)) + + def mapi(self: _TModule, func: ModuleIFunc, recurse: bool = False) -> _TModule: + """ + Applies a function to each module along with its index and + returns a new container with the results. + + Example usage:: + >>> block = Block(nn.LazyLinear(10), nn.LazyLinear(10)) + >>> container.mapi(lambda module, i: module if i % 2 == 0 else nn.ReLU()) + Block(nn.LazyLinear(10), nn.ReLU()) + + Parameters + ---------- + func : Callable[[Module, int], Module] + A function that takes a module and its index, + and returns a modified module. + recurse : bool, optional + Whether to recursively map the function to modules within + sub-containers. Default is False. + + Returns + ------- + Self + A new container with the mapped modules. + """ + + _to_call = _recurse(func, "mapi") if recurse else func + return self.__class__(*(_to_call(module, i) for i, module in enumerate(self))) + + def choose(self: _TModule, func: ModuleMapFunc, recurse: bool = False) -> _TModule: + """ + Returns a new container with modules that are selected by the given function. + + Example usage:: + >>> block = Block(nn.LazyLinear(10), nn.Relu()) + >>> container.choose(lambda m: m if isinstance(m, nn.Linear) else None) + Block(nn.LazyLinear(10)) + + Parameters + ---------- + func : Callable[[Module], Union[Module, None]] + A function that takes a module and returns a module or None. 
+ recurse : bool, optional + Whether to recursively choose modules within sub-containers. Default is False. + + Returns + ------- + Self + A new container with the chosen modules. + """ + + to_add = [] + _to_call = _recurse(func, "choose") if recurse else func + + for module in self: + f_out = _to_call(module) + if f_out: + to_add.append(f_out) + + return self.__class__(*to_add) + + def walk(self: _TModule, func: ModulePredicate) -> _TModule: + """ + Applies a function to each module recursively and returns + a new container with the results. + + Example usage:: + >>> block = Block(Block(nn.LazyLinear(10), nn.ReLU())) + >>> block.walk(lambda m: m if isinstance(m, nn.ReLU) else None) + Block(Block(nn.ReLU())) + + Parameters + ---------- + func : Callable[[Module], Module] + A function that takes a module and returns a modified module. + + Returns + ------- + Self + A new container with the walked modules. + """ + + return self.map(func, recurse=True) + + def zip(self, other: Iterable[_TModule]) -> Iterable[Tuple[_TModule, _TModule]]: + """ + Zips the modules of the container with the modules from another iterable into pairs. + + Example usage:: + >>> list(Block(nn.Linear(10)).zip(Block(nn.Linear(10)))) + [(nn.Linear(10), nn.Linear(10))] + + Parameters + ---------- + other : Iterable[Self] + Another iterable containing modules to be zipped with. + + Returns + ------- + Iterable[Tuple[Self, Self]] + An iterable of pairs containing modules from the container + and the other iterable. + """ + + return builtins.zip(self, other) + + def freeze(self) -> None: + """ + Freezes the parameters of all modules in the container + by setting `requires_grad` to False. + """ + for param in self.parameters(): + param.requires_grad = False + + def unfreeze(self) -> None: + """ + Unfreezes the parameters of all modules in the container + by setting `requires_grad` to True. + """ + for param in self.parameters(): + param.requires_grad = True + + def __add__(self, module) -> _TModule: + if hasattr(module, "__iter__"): + return self.__class__(*self, *module) + + return self.__class__(*self, module) + + def __radd__(self, module) -> _TModule: + if hasattr(module, "__iter__"): + return self.__class__(*module, *self) + + return self.__class__(module, *self) + + +def map( + module: _TModule, + func: ModuleFunc, + recurse: bool = False, + parameterless_modules_only=False, + **kwargs, +) -> _TModule: + """ + Applies a transformation function to a module or a collection of modules. + + Parameters + ---------- + module : nn.Module + The module or collection of modules to which the function will be applied. + func : ModuleFunc + The function that will be applied to the modules. + recurse : bool, optional + Whether to apply the function recursively to child modules. + parameterless_modules_only : bool, optional + Whether to apply the function only to modules without parameters. + **kwargs : dict + Additional keyword arguments that will be passed to the transformation function. + + Returns + ------- + type(module) + The transformed module or collection of modules. 
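+
+    Example usage (a minimal sketch)::
+        >>> mlp = nn.Sequential(nn.Linear(2, 3), nn.ReLU())
+        >>> map(mlp, lambda m: nn.Identity() if isinstance(m, nn.ReLU) else m)
+        Sequential(
+          (0): Linear(in_features=2, out_features=3, bias=True)
+          (1): Identity()
+        )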
+ """ + if hasattr(module, "map"): + to_call = module.map + elif isinstance(module, Iterable): + # Check if the module has .items() method (for dict-like modules) + if hasattr(module, "items"): + to_call = map_module_dict + else: + to_call = map_module_list + else: + to_call = map_module + + return to_call( + module, + func, + parameterless_modules_only=parameterless_modules_only, + recurse=recurse, + **kwargs, + ) + + +def walk( + module: _TModule, func: ModuleFunc, parameterless_modules_only=False, **kwargs +) -> _TModule: + """ + Applies a transformation function recursively to a module or a collection of modules. + + Parameters + ---------- + module : nn.Module + The module or collection of modules to which the function will be applied. + func : ModuleFunc + The function that will be applied to the modules. + parameterless_modules_only : bool, optional + Whether to apply the function only to modules without parameters. + **kwargs : dict + Additional keyword arguments that will be passed to the transformation function. + + Returns + ------- + type(module) + The transformed module or collection of modules. + """ + return map( + module, func, recurse=True, parameterless_modules_only=parameterless_modules_only, **kwargs + ) + + +def map_module( + module: _TModule, func: ModuleFunc, recurse=False, parameterless_modules_only=False, **kwargs +) -> _TModule: + """ + Applies a transformation function to a module and optionally to its child modules. + + Parameters + ---------- + module : nn.Module + The module to which the function will be applied. + func : ModuleFunc + The function that will be applied to the module. + recurse : bool, optional + Whether to apply the function recursively to child modules. + parameterless_modules_only : bool, optional + Whether to apply the function only to modules without parameters. + **kwargs : dict + Additional keyword arguments that will be passed to the transformation function. + + Returns + ------- + nn.Module + The transformed module. + """ + if list(module.parameters(recurse=False)): + new_module = module + else: + new_module = deepcopy(module) + + f_kwargs = _get_func_kwargs(func, **kwargs) + new_module = func(new_module, **f_kwargs) + + if new_module is not module and recurse: + for i, (name, child) in enumerate(module.named_children()): + setattr( + new_module, + name, + map(child, func, recurse, parameterless_modules_only, i=i, name=name), + ) + + return new_module + + +def map_module_list( + module_list: _TModule, func, recurse=False, parameterless_modules_only=False, **kwargs +) -> _TModule: + mapped_modules = [] + for i, module in enumerate(module_list): + new_module = map( + module, + func, + recurse=recurse, + parameterless_modules_only=parameterless_modules_only, + i=i, + name=str(i), + **kwargs, + ) + mapped_modules.append(new_module) + + return _create_list_wrapper(module_list, mapped_modules) + + +def map_module_dict( + module_dict: _TModule, + func: ModuleFunc, + recurse: bool = False, + parameterless_modules_only: bool = False, + **kwargs, +) -> _TModule: + """ + Applies a transformation function to a ModuleDict of modules. + + Parameters + ---------- + module_dict : nn.ModuleDict + The ModuleDict of modules to which the function will be applied. + func : ModuleFunc + The function that will be applied to the modules. + recurse : bool, optional + Whether to apply the function recursively to child modules. + parameterless_modules_only : bool, optional + Whether to apply the function only to modules without parameters. 
+ **kwargs : dict + Additional keyword arguments that will be passed to the transformation function. + + Returns + ------- + nn.ModuleDict + The ModuleDict of transformed modules. + """ + + # Map the function to each module in the dictionary + mapped_modules = {} + for i, (name, module) in enumerate(module_dict.items()): + mapped_modules[name] = map( + module, + func, + recurse=recurse, + parameterless_modules_only=parameterless_modules_only, + name=name, + i=i, + **kwargs, + ) + + return type(module_dict)(mapped_modules) + + +def _create_list_wrapper(module_list, to_add): + # Check the signature of the type constructor + sig = inspect.signature(type(module_list).__init__) + if "args" in sig.parameters: + return type(module_list)(*to_add) # Unpack new_modules + + return type(module_list)(to_add) # Don't unpack new_modules + + +def _get_func_kwargs(func, **kwargs): + sig = inspect.signature(func) + f_kwargs = {} + if "i" in sig.parameters and "i" in kwargs: + f_kwargs["i"] = kwargs["i"] + if "name" in sig.parameters and "name" in kwargs: + f_kwargs["name"] = kwargs["name"] + + return f_kwargs + + +def _recurse(func, to_recurse_name: str): + @wraps(func) + def inner(module, *args, **kwargs): + if hasattr(module, to_recurse_name): + fn = getattr(module, to_recurse_name) + + return fn(func, *args, **kwargs) + + return func(module, *args, **kwargs) + + return inner diff --git a/merlin/models/torch/inputs/embedding.py b/merlin/models/torch/inputs/embedding.py index fe8cf28170..b49965699a 100644 --- a/merlin/models/torch/inputs/embedding.py +++ b/merlin/models/torch/inputs/embedding.py @@ -82,9 +82,12 @@ def __init__( self.has_combiner = self.seq_combiner is not None self.has_module_combiner = isinstance(self.seq_combiner, nn.Module) self.num_embeddings = 0 - self.setup_schema(schema or Schema()) + self.input_schema = None + if schema: + self.initialize_from_schema(schema or Schema()) + self._initialized_from_schema = True - def setup_schema(self, schema: Schema): + def initialize_from_schema(self, schema: Schema): """ Sets up the schema for the embedding table. @@ -374,6 +377,14 @@ def update_feature(self, col_schema: ColumnSchema) -> "EmbeddingTable": return self + def feature_weights(self, name: str): + if name not in self.domains: + raise ValueError(f"{name} not found in table: {self}") + + domain = self.domains[name] + + return self.table.weight[int(domain.min) : int(domain.max)] + def select(self, selection: Selection) -> Selectable: selected = select(self.input_schema, selection) @@ -454,17 +465,17 @@ def __init__( self.seq_combiner = seq_combiner self.kwargs = kwargs if isinstance(schema, Schema): - self.setup_schema(schema) - - def setup_schema(self, schema: Schema): - """ - Sets up the schema for the embedding tables. + self.initialize_from_schema(schema) + self._initialized_from_schema = True - Args: - schema (Schema): The schema to setup. + def initialize_from_schema(self, schema: Schema): + """Initializes the module from a schema. + Called during the schema tracing of the model. - Returns: - EmbeddingTables: The updated EmbeddingTables instance with the setup schema. + Parameters + ---------- + schema : Schema + The schema to initialize with """ self.schema = schema diff --git a/merlin/models/torch/inputs/select.py b/merlin/models/torch/inputs/select.py index 6456e807a6..206083b838 100644 --- a/merlin/models/torch/inputs/select.py +++ b/merlin/models/torch/inputs/select.py @@ -14,7 +14,7 @@ # limitations under the License. 
# -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import torch from torch import nn @@ -24,7 +24,7 @@ from merlin.schema import ColumnSchema, Schema, Tags -class SelectKeys(nn.Module, schema.Selectable): +class SelectKeys(nn.Module, schema.Selectable, schema.LazySchemaModuleMixin): """Filter tabular data based on a defined schema. Example usage:: @@ -50,20 +50,24 @@ class SelectKeys(nn.Module, schema.Selectable): List of column names in the schema. """ - def __init__(self, schema: Optional[Schema] = None): + def __init__(self, schema: Optional[Union[Schema, ColumnSchema]] = None): super().__init__() - self.column_names: List[str] = [] - if schema: - self.setup_schema(schema) - - def setup_schema(self, schema: Schema): if isinstance(schema, ColumnSchema): schema = Schema([schema]) - + if schema: + self.initialize_from_schema(schema) + self._initialized_from_schema = True + else: + schema = Schema() + self.schema = schema + self.column_names: List[str] = schema.column_names + + def initialize_from_schema(self, schema: Schema): + super().initialize_from_schema(schema) self.schema = schema + self.column_names = schema.column_names self.input_schema = schema self.output_schema = schema - self.column_names = schema.column_names def select(self, selection: schema.Selection) -> "SelectKeys": """Select a subset of the schema based on the provided selection. @@ -125,7 +129,7 @@ def __eq__(self, other) -> bool: return set(self.column_names) == set(other.column_names) -class SelectFeatures(nn.Module): +class SelectFeatures(nn.Module, schema.LazySchemaModuleMixin): """Filter tabular data based on a defined schema. It operates similarly to SelectKeys, but it uses the features from Batch. @@ -154,9 +158,9 @@ def __init__(self, schema: Optional[Schema] = None): super().__init__() self.select_keys = SelectKeys(schema=schema) if schema: - self.setup_schema(schema) + self.initialize_from_schema(schema) - def setup_schema(self, schema: Schema): + def initialize_from_schema(self, schema: Schema): """Set up the schema for the SelectFeatures. Parameters @@ -164,11 +168,11 @@ def setup_schema(self, schema: Schema): schema : Schema The schema to use for selection. """ - self.select_keys.setup_schema(schema) + super().initialize_from_schema(schema) + self.select_keys.initialize_from_schema(schema) self.embedding_names = schema.select_by_tag(Tags.EMBEDDING).column_names - self.input_schema = self.select_keys.input_schema - self.feature_schema = self.input_schema - self.output_schema = self.select_keys.output_schema + self.input_schema = Schema() + self.output_schema = schema def select(self, selection: schema.Selection) -> "SelectFeatures": """Select a subset of the schema based on the provided selection. 
@@ -187,6 +191,9 @@ def select(self, selection: schema.Selection) -> "SelectFeatures": return SelectFeatures(schema) + def compute_feature_schema(self, feature_schema: Schema) -> Schema: + return feature_schema[self.select_keys.column_names] + def forward(self, inputs, batch: Batch) -> Dict[str, torch.Tensor]: outputs = {} selected = self.select_keys(batch.features) @@ -201,8 +208,8 @@ def forward(self, inputs, batch: Batch) -> Dict[str, torch.Tensor]: @schema.extract.register(SelectKeys) def _(main, selection, route, name=None): - main_schema = schema.input(main) - route_schema = schema.input(route) + main_schema = schema.input_schema(main) + route_schema = schema.input_schema(route) diff = main_schema.excluding_by_name(route_schema.column_names) diff --git a/merlin/models/torch/inputs/tabular.py b/merlin/models/torch/inputs/tabular.py index 0d73656e20..fb1ed3ac6a 100644 --- a/merlin/models/torch/inputs/tabular.py +++ b/merlin/models/torch/inputs/tabular.py @@ -14,13 +14,15 @@ # limitations under the License. # -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Sequence, Union from torch import nn from merlin.models.torch.block import Block from merlin.models.torch.inputs.embedding import EmbeddingTables from merlin.models.torch.router import RouterBlock +from merlin.models.torch.schema import Selection, select, select_union +from merlin.models.torch.transforms.sequences import BroadcastToSequence from merlin.models.utils.registry import Registry from merlin.schema import Schema, Tags @@ -51,22 +53,26 @@ class TabularInputBlock(RouterBlock): def __init__( self, - schema: Schema, + schema: Optional[Schema] = None, init: Optional[Union[str, Initializer]] = None, agg: Optional[Union[str, nn.Module]] = None, ): + self.init = init super().__init__(schema) - self.schema: Schema = self.selectable.schema - if init: - if isinstance(init, str): - init = self.initializers.get(init) - if not init: - raise ValueError(f"Initializer {init} not found.") - - init(self) if agg: self.append(Block.parse(agg)) + def initialize_from_schema(self, schema: Schema): + super().initialize_from_schema(schema) + self.schema: Schema = self.selectable.schema + if self.init: + if isinstance(self.init, str): + self.init = self.initializers.get(self.init) + if not self.init: + raise ValueError(f"Initializer {self.init} not found.") + + self.init(self) + @classmethod def register_init(cls, name: str): """ @@ -91,7 +97,7 @@ def defaults(block: TabularInputBlock): @TabularInputBlock.register_init("defaults") -def defaults(block: TabularInputBlock): +def defaults(block: TabularInputBlock, seq_combiner="mean"): """ Default initializer function for a TabularInputBlock. @@ -101,5 +107,69 @@ def defaults(block: TabularInputBlock): Args: block (TabularInputBlock): The block to initialize. 
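+        seq_combiner (str, optional): The combiner used for sequential categorical
+            features, by default "mean".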
""" - block.add_route(Tags.CONTINUOUS) - block.add_route(Tags.CATEGORICAL, EmbeddingTables(seq_combiner="mean")) + block.add_route(Tags.CONTINUOUS, required=False) + block.add_route(Tags.CATEGORICAL, EmbeddingTables(seq_combiner=seq_combiner)) + + +@TabularInputBlock.register_init("defaults-no-combiner") +def defaults_no_combiner(block: TabularInputBlock): + return defaults(block, seq_combiner=None) + + +@TabularInputBlock.register_init("broadcast-context") +def defaults_broadcast_to_seq( + block: TabularInputBlock, + seq_selection: Selection = Tags.SEQUENCE, + feature_selection: Sequence[Selection] = (Tags.CATEGORICAL, Tags.CONTINUOUS), +): + context_selection = _not_seq(seq_selection, feature_selection=feature_selection) + block.add_route(context_selection, TabularInputBlock(init="defaults"), name="context") + block.add_route( + seq_selection, + TabularInputBlock(init="defaults-no-combiner"), + name="sequence", + ) + block.append(BroadcastToSequence(context_selection, seq_selection, block.schema)) + + +def stack_context( + model_dim: int, + seq_selection: Selection = Tags.SEQUENCE, + projection_activation=None, + feature_selection: Sequence[Selection] = (Tags.CATEGORICAL, Tags.CONTINUOUS), +): + def init_stacked_context(block: TabularInputBlock): + import merlin.models.torch as mm + + mlp_kwargs = {"units": [model_dim], "activation": projection_activation} + context_selection = _not_seq(seq_selection, feature_selection=feature_selection) + context = TabularInputBlock(select(block.schema, context_selection)) + context.add_route(Tags.CATEGORICAL, EmbeddingTables(seq_combiner=None)) + context.add_route(Tags.CONTINUOUS, mm.MLPBlock(**mlp_kwargs)) + context["categorical"].append_for_each(mm.MLPBlock(**mlp_kwargs)) + context.append(mm.Stack(dim=1)) + + block.add_route(context.schema, context, name="context") + block.add_route( + seq_selection, + TabularInputBlock(init="defaults-no-combiner", agg=mm.Concat(dim=2)), + name="sequence", + ) + + return init_stacked_context + + +def _not_seq( + seq_selection: Sequence[Selection], + feature_selection: Sequence[Selection] = (Tags.CATEGORICAL, Tags.CONTINUOUS), +) -> Selection: + if not isinstance(seq_selection, (tuple, list)): + seq_selection = (seq_selection,) + + def select_non_seq(schema: Schema) -> Schema: + seq = select_union(*seq_selection)(schema) + features = select_union(*feature_selection)(schema) + + return features - seq + + return select_non_seq diff --git a/merlin/models/torch/models/base.py b/merlin/models/torch/models/base.py index df1826746c..4455c4ba2d 100644 --- a/merlin/models/torch/models/base.py +++ b/merlin/models/torch/models/base.py @@ -13,20 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Dict, List, Optional, Sequence, Union +import inspect +import itertools +import os +from typing import Dict, Iterator, List, Optional, Sequence, Type, Union import torch -from pytorch_lightning import LightningModule -from torch import nn +from packaging import version +from pytorch_lightning import LightningDataModule, LightningModule +from torch import nn, optim from merlin.dataloader.torch import Loader from merlin.io import Dataset from merlin.models.torch.batch import Batch -from merlin.models.torch.block import Block +from merlin.models.torch.block import BatchBlock, Block from merlin.models.torch.outputs.base import ModelOutput from merlin.models.torch.utils import module_utils from merlin.models.utils.registry import camelcase_to_snakecase +OptimizerType = Union[optim.Optimizer, Type[optim.Optimizer], str] + +LRScheduler = ( + optim.lr_scheduler._LRScheduler + if version.parse(torch.__version__).major < 2 + else optim.lr_scheduler.LRScheduler +) + +LRSchedulerType = Union[LRScheduler, Type[LRScheduler]] + class Model(LightningModule, Block): """ @@ -42,8 +56,11 @@ class Model(LightningModule, Block): schema: Schema, optional A Merlin schema. Default is None. optimizer: torch.optim.Optimizer, optional - A PyTorch optimizer from the PyTorch library (or any custom optimizer + A PyTorch optimizer instance or class from the PyTorch library (or any custom optimizer that follows the same API). Default is Adam optimizer. + scheduler: torch.optim.lr_scheduler.LRScheduler, optional + A PyTorch learning rate scheduler instance from the PyTorch library (or any custom scheduler + that follows the same API). Default is None, which means no LR decay. Example usage ------------- @@ -53,15 +70,15 @@ class Model(LightningModule, Block): ... BinaryOutput(schema.select_by_tag(Tags.TARGET).first), ... ) ... trainer = Trainer(max_epochs=1) - ... with Loader(dataset, batch_size=16) as loader: - ... model.initialize(loader) - ... trainer.fit(model, loader) + ... 
trainer.fit(model, Loader(dataset, batch_size=16)) """ def __init__( self, *blocks: nn.Module, optimizer=torch.optim.Adam, + initialization="auto", + pre: Optional[BatchBlock] = None, ): super().__init__() @@ -69,8 +86,51 @@ def __init__( self.values = nn.ModuleList() for module in blocks: self.values.append(self.wrap_module(module)) + self.initialization = initialization + if isinstance(pre, BatchBlock): + self.pre = pre + elif pre is None: + self.pre = BatchBlock() + else: + raise ValueError(f"Invalid pre: {pre}, must be a BatchBlock") + + @property + @torch.jit.ignore + def optimizer(self): + return self._optimizer if hasattr(self, "_optimizer") else None + + def configure_optimizers( + self, + optimizer: Optional[OptimizerType] = None, + scheduler: Optional[LRSchedulerType] = None, + ): + """Configures the optimizer for the model.""" + if optimizer is None: + optimizer = self._optimizer if hasattr(self, "_optimizer") else "adam" + self._optimizer = create_optimizer(self, optimizer) - self.optimizer = optimizer + if scheduler is None: + if hasattr(self, "_scheduler"): + scheduler = self._scheduler + else: + self._scheduler = None + if scheduler is not None: + self._scheduler = get_scheduler(self._optimizer, scheduler) + + if not isinstance(self._optimizer, (list, tuple)): + opt = [self._optimizer] + else: + opt = self._optimizer + + if self._scheduler is not None: + if not isinstance(self._scheduler, (list, tuple)): + sched = [self._scheduler] + else: + sched = self._scheduler + + return opt, sched + + return opt def initialize(self, data: Union[Dataset, Loader, Batch]): """Initializes the model based on a given data set.""" @@ -80,9 +140,11 @@ def forward( self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None ): """Performs a forward pass through the model.""" - outputs = inputs + _batch: Batch = self.pre(inputs, batch=batch) + + outputs = _batch.inputs() for block in self.values: - outputs = block(outputs, batch=batch) + outputs = block(outputs, batch=_batch) return outputs def training_step(self, batch, batch_idx): @@ -96,15 +158,32 @@ def training_step(self, batch, batch_idx): predictions = self(features, batch=Batch(features, targets)) - loss_and_metrics = compute_loss(predictions, targets, self.model_outputs()) + loss_and_metrics = compute_loss( + predictions, targets, self.model_outputs(), compute_metrics=True + ) for name, value in loss_and_metrics.items(): self.log(f"train_{name}", value) return loss_and_metrics["loss"] - def configure_optimizers(self): - """Configures the optimizer for the model.""" - return self.optimizer(self.parameters()) + def validation_step(self, batch, batch_idx): + return self._val_step(batch, batch_idx, type="val") + + def test_step(self, batch, batch_idx): + return self._val_step(batch, batch_idx, type="test") + + def _val_step(self, batch, batch_idx, type="val"): + del batch_idx + if not isinstance(batch, Batch): + batch = Batch(features=batch[0], targets=batch[1]) + + predictions = self(batch.features, batch=batch) + + loss_and_metrics = compute_loss(predictions, batch.targets, self.model_outputs()) + for name, value in loss_and_metrics.items(): + self.log(f"{type}_{name}", value) + + return loss_and_metrics def model_outputs(self) -> List[ModelOutput]: """Finds all instances of `ModelOutput` in the model.""" @@ -118,6 +197,150 @@ def last(self) -> nn.Module: """Returns the last block in the model.""" return self.values[-1] + def setup(self, stage): + """Initialize the model if `initialization="auto"`.""" + if 
self.initialization == "auto": + loop = getattr(self.trainer, f"{stage}_loop") + + data_instance = loop._data_source.instance + if isinstance(data_instance, MultiLoader): + self.initialize(data_instance.batch.to(None, device=self.device)) + else: + dataloader = loop._data_source.dataloader() + if isinstance(dataloader, Loader): + self.initialize(dataloader) + else: + raise ValueError( + f"Can't auto-initialize from a non-merlin dataloader, got: {dataloader}", + "Please initialize the model manually with `model.initialize(batch)`", + ) + + def teardown(self, stage: str) -> None: + """Teardown the data-loader after training.""" + loop = getattr(self.trainer, f"{stage}_loop") + dataloader = loop._data_source.dataloader() + if isinstance(dataloader, Loader): + dataloader.stop() + + +class MultiLoader(LightningDataModule): + """ + Data Module for handling multiple types of data loaders. It facilitates the usage + of multiple datasets, as well as distributed training on multiple GPUs. + + This class is particularly useful in scenarios where you have separate train, + validation and test datasets, and you want to use PyTorch Lightning's Trainer + which requires a single DataModule. + + Parameters + ---------- + train : Union[Dataset, Loader] + Training dataset or data loader. + valid : Optional[Union[Dataset, Loader]], optional + Validation dataset or data loader, by default None + test : Optional[Union[Dataset, Loader]], optional + Test dataset or data loader, by default None + repartition : int, optional + Number of partitions to divide the dataset into, by default None + batch_size : int, optional + Number of data points per batch, by default 1024 + + + Example usage for multi-GPU:: + model = mm.Model(...) + train, valid = generate_data(...) + model.initialize(train) + + trainer = pl.Trainer(max_epochs=5, devices=[0, 1]) + trainer.fit(model, mm.MultiLoader(train, valid, batch_size=1024, repartition=4)) + """ + + def __init__( + self, + train: Union[Dataset, Loader], + valid: Optional[Union[Dataset, Loader]] = None, + test: Optional[Union[Dataset, Loader]] = None, + batch_size: int = 1024, + repartition: Optional[int] = None, + ): + super().__init__() + self.repartition = repartition + self.train = train + self.batch_size = batch_size + self.batch = Batch.sample_from(train, batch_size=1, shuffle=False) + if valid: + self.val_dataloader = lambda: self._create_loader(valid, "valid") + if test: + self.test_dataloader = lambda: self._create_loader(test, "test") + + def train_dataloader(self) -> Loader: + return self._create_loader(self.train, "train") + + def _create_loader(self, data: Union[Dataset, Loader], name: str) -> Loader: + """ + Create a data loader with the right arguments. + + Parameters + ---------- + data : Union[Dataset, Loader] + The input data, can be a dataset or data loader. + name : str + Name of the data loader. + + Returns + ------- + Loader + The created data loader. 
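+
+        Notes
+        -----
+        When the ``WORLD_SIZE`` environment variable is set (distributed
+        training), the dataset is repartitioned accordingly and the loader is
+        created with the matching ``global_size`` and ``global_rank`` so that
+        each worker reads a distinct shard.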
+ """ + + _dataset = data.dataset if isinstance(data, Loader) else data + + has_world_size = "WORLD_SIZE" in os.environ + + if self.repartition: + npartitions = self.repartition + elif has_world_size: + npartitions = int(os.environ["WORLD_SIZE"]) + elif isinstance(data, Loader): + npartitions = data.global_size + else: + npartitions = None + + if npartitions: + _dataset = _dataset.repartition(npartitions=npartitions) + + if isinstance(data, Loader): + output = Loader( + _dataset, + batch_size=data.batch_size, + shuffle=data.shuffle, + drop_last=int(os.environ["WORLD_SIZE"]) > 1 if has_world_size else data.drop_last, + global_size=int(os.environ["WORLD_SIZE"]) if has_world_size else data.global_size, + global_rank=int(os.environ["LOCAL_RANK"]) if has_world_size else data.global_rank, + transforms=data.transforms, + ) + else: + output = Loader( + _dataset, + batch_size=self.batch_size, + drop_last=int(os.environ["WORLD_SIZE"]) > 1 if has_world_size else False, + global_size=int(os.environ["WORLD_SIZE"]) if has_world_size else None, + global_rank=int(os.environ["LOCAL_RANK"]) if has_world_size else None, + ) + + setattr(self, f"loader_{name}", output) + return output + + def teardown(self, stage): + """ + Stop all data loaders. + """ + for attr in dir(self): + if attr.startswith("loader"): + if hasattr(getattr(self, attr), "stop"): + getattr(self, attr).stop() + delattr(self, attr) + def compute_loss( predictions: Union[torch.Tensor, Dict[str, torch.Tensor]], @@ -210,5 +433,111 @@ def compute_loss( for metric in model_out.metrics: metric_name = camelcase_to_snakecase(metric.__class__.__name__) + if not metric.device or metric.device != _predictions.device: + metric = metric.to(_predictions.device) + results[metric_name] = metric(_predictions, _targets) return results + + +def create_optimizer(module: nn.Module, opt: OptimizerType) -> optim.Optimizer: + """ + Creates an optimizer given a PyTorch module and an optimizer type. + + Parameters + ---------- + module : torch.nn.Module + The PyTorch model. + opt : str, Type[torch.optim.Optimizer], or torch.optim.Optimizer + The optimizer type, either as a string, a class, or an existing + PyTorch optimizer object. + + Returns + ------- + torch.optim.Optimizer + A PyTorch optimizer. + + Raises + ------ + ValueError + If the provided string for opt does not correspond to a known optimizer type. + TypeError + If the type of opt is neither string, class of torch.optim.Optimizer, + nor instance of torch.optim.Optimizer. 
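+
+    Example usage::
+        >>> module = nn.Linear(2, 1)
+        >>> optimizer = create_optimizer(module, "adam")
+        >>> optimizer = create_optimizer(module, optim.SGD)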
+    """
+
+    # Extract the model parameters
+    params = module.parameters()
+
+    # If opt is a string, create a new optimizer of the given type
+    if isinstance(opt, str):
+        if opt.lower() == "sgd":
+            return optim.SGD(params, lr=0.01)
+        elif opt.lower() == "adam":
+            return optim.Adam(params, lr=0.001)
+        elif opt.lower() == "adagrad":
+            return optim.Adagrad(params, lr=0.01)
+        else:
+            raise ValueError(f"Unsupported optimizer type: {opt}")
+
+    # If opt is an optimizer class, create a new optimizer of the given type
+    elif isinstance(opt, type) and issubclass(opt, optim.Optimizer):
+        return opt(params, lr=0.01)
+
+    # If opt is an optimizer instance, create a new optimizer of the same type
+    elif isinstance(opt, optim.Optimizer):
+        # Flattens a list of lists (or other iterable)
+        def flatten(lis: Iterator[Iterator]) -> Iterator:
+            return list(itertools.chain.from_iterable(lis))
+
+        # Extract parameters from optimizer's param_groups
+        params_opt = flatten([group["params"] for group in opt.param_groups])
+        params_module = list(module.parameters())
+
+        # Check if the parameters of the module and the optimizer are the same
+        if params_module == params_opt:
+            # If parameters are the same, return the existing optimizer
+            return opt
+        else:
+            # If parameters are not the same, create a new optimizer of the same type
+            opt_type = type(opt)
+            return opt_type(params_module, **opt.defaults)
+
+    raise TypeError(
+        "Expected opt to be a string, a class of torch.optim.Optimizer, "
+        f"or an instance of torch.optim.Optimizer, but got {type(opt)}"
+    )
+
+
+def get_scheduler(optimizer: optim.Optimizer, scheduler: LRSchedulerType) -> LRScheduler:
+    """
+    Get an instance of a learning rate scheduler.
+
+    Parameters
+    ----------
+    optimizer : torch.optim.Optimizer
+        The optimizer to which the scheduler should be applied.
+    scheduler : LRSchedulerType
+        The scheduler or scheduler class to use.
+        If an instance is provided and its optimizer differs from the provided optimizer,
+        a new instance of the same type is returned with the provided optimizer.
+        If the optimizers are the same, the original scheduler is returned.
+        If a class is provided, an instance is created with the optimizer as the only argument.
+
+    Returns
+    -------
+    torch.optim.lr_scheduler._LRScheduler
+        The scheduler instance.
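+
+    Example usage (a minimal sketch)::
+        >>> scheduler = get_scheduler(optimizer, optim.lr_scheduler.StepLR(optimizer, step_size=10))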
+    """
+    if isinstance(scheduler, LRScheduler):
+        if scheduler.optimizer != optimizer:
+            return type(scheduler)(optimizer)
+        else:
+            return scheduler
+    elif inspect.isclass(scheduler) and issubclass(scheduler, LRScheduler):
+        return scheduler(optimizer)
+
+    raise TypeError(
+        "scheduler must be a subclass or instance of optim.lr_scheduler.LRScheduler, "
+        f"got: {scheduler}"
+    )
diff --git a/merlin/models/torch/models/ranking.py b/merlin/models/torch/models/ranking.py
index 292abebbd8..2bd9eb895c 100644
--- a/merlin/models/torch/models/ranking.py
+++ b/merlin/models/torch/models/ranking.py
@@ -2,13 +2,19 @@
 
 from torch import nn
 
-from merlin.models.torch.block import Block
-from merlin.models.torch.blocks.dlrm import DLRMBlock
+from merlin.models.torch.block import Block, ParallelBlock
+from merlin.models.torch.blocks.cross import _DCNV2_REF, CrossBlock
+from merlin.models.torch.blocks.dlrm import _DLRM_REF, DLRMBlock
+from merlin.models.torch.blocks.mlp import MLPBlock
+from merlin.models.torch.inputs.tabular import TabularInputBlock
 from merlin.models.torch.models.base import Model
 from merlin.models.torch.outputs.tabular import TabularOutputBlock
+from merlin.models.torch.transforms.agg import Concat, MaybeAgg
+from merlin.models.utils.doc_utils import docstring_parameter
 from merlin.schema import Schema
 
 
+@docstring_parameter(dlrm_reference=_DLRM_REF)
 class DLRMModel(Model):
     """
     The Deep Learning Recommendation Model (DLRM) as proposed in Naumov, et al. [1]
@@ -42,15 +48,12 @@ class DLRMModel(Model):
     ...     schema,
     ...     dim=64,
     ...     bottom_block=mm.MLPBlock([256, 64]),
-    ...     output_block=BinaryOutput(ColumnSchema("target")))
+    ...     output_block=mm.BinaryOutput(ColumnSchema("target")),
+    ... )
     >>> trainer = pl.Trainer()
-    >>> model.initialize(dataloader)
-    >>> trainer.fit(model, dataloader)
+    >>> trainer.fit(model, Loader(dataset, batch_size=32))
 
-    References
-    ----------
-    [1] Naumov, Maxim, et al. "Deep learning recommendation model for
-    personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
+    {dlrm_reference}
     """
 
     def __init__(
@@ -74,3 +77,74 @@ def __init__(
         )
 
         super().__init__(dlrm_body, output_block)
+
+
+@docstring_parameter(dcn_reference=_DCNV2_REF)
+class DCNModel(Model):
+    """
+    The Deep & Cross Network (DCN) architecture as proposed in Wang, et al. [1]
+
+    Parameters
+    ----------
+    schema : Schema
+        The schema to use for selection.
+    depth : int, optional
+        Number of cross-layers to be stacked, by default 1
+    deep_block : Block, optional
+        The `Block` to use as the deep part of the model (typically an `MLPBlock`)
+    stacked : bool
+        Whether to use the stacked version of the model or the parallel version.
+    input_block : Block, optional
+        The `Block` to use as the input layer. If None, a default `TabularInputBlock` object
+        is instantiated, which creates the embedding tables for the categorical features
+        based on the schema. The embedding dimensions are inferred from the features
+        cardinality. For a custom representation of input data you can instantiate
+        and provide a `TabularInputBlock` instance.
+
+    Returns
+    -------
+    Model
+        An instance of Model class representing the fully formed DCN.
+
+    Example usage
+    -------------
+    >>> model = mm.DCNModel(
+    ...     schema,
+    ...     depth=2,
+    ...     deep_block=mm.MLPBlock([256, 64]),
+    ...     output_block=mm.BinaryOutput(ColumnSchema("target")),
+    ...
) + >>> trainer = pl.Trainer() + >>> model.initialize(dataloader) + >>> trainer.fit(model, dataloader) + + {dcn_reference} + """ + + def __init__( + self, + schema: Schema, + depth: int = 1, + deep_block: Optional[Block] = None, + stacked: bool = True, + input_block: Optional[Block] = None, + output_block: Optional[Block] = None, + ) -> None: + if input_block is None: + input_block = TabularInputBlock(schema, init="defaults") + + if output_block is None: + output_block = TabularOutputBlock(schema, init="defaults") + + if deep_block is None: + deep_block = MLPBlock([512, 256]) + + if stacked: + cross_network = Block(CrossBlock.with_depth(depth), deep_block) + else: + cross_network = Block( + ParallelBlock({"cross": CrossBlock.with_depth(depth), "deep": deep_block}), + MaybeAgg(Concat()), + ) + + super().__init__(input_block, *cross_network, output_block) diff --git a/merlin/models/torch/outputs/base.py b/merlin/models/torch/outputs/base.py index 8ff5f15e9c..8f3b78acb0 100644 --- a/merlin/models/torch/outputs/base.py +++ b/merlin/models/torch/outputs/base.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import inspect from copy import deepcopy from typing import Optional, Sequence @@ -21,7 +22,7 @@ from torchmetrics import Metric from merlin.models.torch.block import Block -from merlin.schema import ColumnSchema, Schema +from merlin.models.torch.transforms.bias import LogitsTemperatureScaler class ModelOutput(Block): @@ -47,12 +48,13 @@ class ModelOutput(Block): Parameters ---------- - schema: Optional[ColumnSchema] - The schema defining the column properties. loss: nn.Module The loss function used for training. metrics: Sequence[Metric] The metrics used for evaluation. + logits_temperature: float, optional + Parameter used to reduce model overconfidence, so that logits / T. + by default 1.0 name: Optional[str] The name of the model output. """ @@ -60,9 +62,9 @@ class ModelOutput(Block): def __init__( self, *module: nn.Module, - schema: Optional[ColumnSchema] = None, loss: Optional[nn.Module] = None, - metrics: Sequence[Metric] = (), + metrics: Optional[Sequence[Metric]] = None, + logits_temperature: float = 1.0, name: Optional[str] = None, ): """Initializes a ModelOutput object.""" @@ -70,21 +72,10 @@ def __init__( self.loss = loss self.metrics = metrics - self.output_schema: Schema = Schema() - if schema: - self.setup_schema(schema) self.create_target_buffer() - - def setup_schema(self, schema: Optional[ColumnSchema]): - """Set up the schema for the output. - - Parameters - ---------- - schema: ColumnSchema or None - The schema defining the column properties. 
- """ - self.output_schema = Schema([schema]) + if logits_temperature != 1.0: + self.append(LogitsTemperatureScaler(logits_temperature)) def create_target_buffer(self): self.register_buffer("target", torch.zeros(1, dtype=torch.float32)) @@ -103,18 +94,24 @@ def eval(self): return self.train(False) def copy(self): - metrics = self.metrics + metrics = deepcopy(self.metrics) self.metrics = [] output = deepcopy(self) copied_metrics = [] for metric in metrics: - m = metric.__class__() + params = inspect.signature(metric.__class__.__init__).parameters + kwargs = {} + for arg_name, arg_value in params.items(): + if arg_name in metric.__dict__: + kwargs[arg_name] = metric.__dict__[arg_name] + m = metric.__class__(**kwargs) m.load_state_dict(metric.state_dict()) copied_metrics.append(m) self.metrics = metrics output.metrics = copied_metrics + output.loss = deepcopy(self.loss) return output diff --git a/merlin/models/torch/outputs/classification.py b/merlin/models/torch/outputs/classification.py index 2ca36143f7..b4123385e8 100644 --- a/merlin/models/torch/outputs/classification.py +++ b/merlin/models/torch/outputs/classification.py @@ -13,14 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Optional, Sequence, Union +import inspect +from typing import List, Optional, Sequence, Type, Union +import torch +import torchmetrics as tm from torch import nn -from torchmetrics import AUROC, Accuracy, Metric, Precision, Recall import merlin.dtypes as md +from merlin.models.torch import schema +from merlin.models.torch.batch import Batch +from merlin.models.torch.inputs.embedding import EmbeddingTable from merlin.models.torch.outputs.base import ModelOutput -from merlin.schema import ColumnSchema, Schema +from merlin.schema import ColumnSchema, Schema, Tags class BinaryOutput(ModelOutput): @@ -28,7 +33,7 @@ class BinaryOutput(ModelOutput): Parameters ---------- - schema: Optional[ColumnSchema]) + schema: Union[ColumnSchema, Schema], optional The schema defining the column properties. Default is None. loss: nn.Module The loss function used for training. Default is nn.BCEWithLogitsLoss(). @@ -36,25 +41,30 @@ class BinaryOutput(ModelOutput): The metrics used for evaluation. Default includes Accuracy, AUROC, Precision, and Recall. """ - DEFAULT_LOSS_CLS = nn.BCEWithLogitsLoss - DEFAULT_METRICS_CLS = (Accuracy, AUROC, Precision, Recall) + DEFAULT_LOSS_CLS = nn.BCELoss + DEFAULT_METRICS_CLS = (tm.Accuracy, tm.AUROC, tm.Precision, tm.Recall) def __init__( self, schema: Optional[ColumnSchema] = None, loss: Optional[nn.Module] = None, - metrics: Sequence[Metric] = (), + metrics: Sequence[tm.Metric] = (), ): """Initializes a BinaryOutput object.""" super().__init__( nn.LazyLinear(1), nn.Sigmoid(), - schema=schema, loss=loss or self.DEFAULT_LOSS_CLS(), metrics=metrics or [m(task="binary") for m in self.DEFAULT_METRICS_CLS], ) + if schema: + self.initialize_from_schema(schema) + self._initialized_from_schema = True - def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]): + if not self.metrics: + self.metrics = self.default_metrics() + + def initialize_from_schema(self, target: Optional[Union[ColumnSchema, Schema]]): """Set up the schema for the output. 
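+
+        This is also called lazily during the schema tracing of the model when
+        no target is passed to the constructor.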
Parameters @@ -75,3 +85,391 @@ def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]): ) self.output_schema = Schema([_target]) + + @classmethod + def schema_selection(cls, schema: Schema) -> Schema: + """Returns a schema containing all binary targets.""" + output = Schema() + output += schema.select_by_tag([Tags.BINARY_CLASSIFICATION, Tags.BINARY]) + for col in schema.select_by_tag([Tags.CATEGORICAL]): + if col.int_domain and col.int_domain.max == 1: + output += Schema([col]) + + return output + + +class CategoricalOutput(ModelOutput): + """ + A prediction block for categorical targets. + + Parameters + ---------- + schema: Union[ColumnSchema, Schema], optional + The schema defining the column properties. Default is None. + loss : nn.Module, optional + The loss function to use for the output model, defaults to + torch.nn.CrossEntropyLoss. + metrics : Optional[Sequence[Metric]], optional + The metrics to evaluate the model output. + logits_temperature: float, optional + Parameter used to reduce model overconfidence, so that logits / T. + by default 1.0 + """ + + DEFAULT_LOSS_CLS = nn.CrossEntropyLoss + DEFAULT_METRICS_CLS = ( + tm.RetrievalHitRate, + tm.RetrievalNormalizedDCG, + tm.RetrievalPrecision, + tm.RetrievalRecall, + ) + DEFAULT_K = (5,) + + def __init__( + self, + schema: Optional[Union[ColumnSchema, Schema]] = None, + loss: Optional[nn.Module] = None, + metrics: Optional[Sequence[tm.Metric]] = None, + logits_temperature: float = 1.0, + ): + super().__init__( + loss=loss or self.DEFAULT_LOSS_CLS(), + metrics=metrics or create_retrieval_metrics(self.DEFAULT_METRICS_CLS, self.DEFAULT_K), + logits_temperature=logits_temperature, + ) + + if schema: + self.initialize_from_schema(schema) + self._initialized_from_schema = True + + @classmethod + def with_weight_tying( + cls, + block: nn.Module, + selection: Optional[schema.Selection] = None, + loss: nn.Module = nn.CrossEntropyLoss(), + metrics: Optional[Sequence[tm.Metric]] = None, + logits_temperature: float = 1.0, + ) -> "CategoricalOutput": + self = cls(loss=loss, metrics=metrics, logits_temperature=logits_temperature) + self = self.tie_weights(block, selection) + if not self.metrics: + self.metrics = self.default_metrics(self.num_classes) + + return self + + def tie_weights( + self, block: nn.Module, selection: Optional[schema.Selection] = None + ) -> "CategoricalOutput": + prediction = EmbeddingTablePrediction.with_weight_tying(block, selection) + self.num_classes = prediction.num_classes + if self: + self[0] = prediction + else: + self.prepend(prediction) + + return self + + def initialize_from_schema(self, target: Optional[Union[ColumnSchema, Schema]]): + """Set up the schema for the output. + + Parameters + ---------- + target: Optional[ColumnSchema] + The schema defining the column properties. 
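+
+        Example usage (a sketch; assumes an ``item_id`` column whose int domain
+        gives the number of classes)::
+
+            >>> output = CategoricalOutput(schema["item_id"])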
+ """ + if not isinstance(target, (ColumnSchema, Schema)): + raise ValueError(f"Target must be a ColumnSchema or Schema, got {target}.") + + if isinstance(target, Schema): + if len(target) != 1: + raise ValueError("Schema must contain exactly one column.") + + target = target.first + + to_call = CategoricalTarget(target) + self.num_classes = to_call.num_classes + self.prepend(to_call) + + @classmethod + def schema_selection(cls, schema: Schema) -> Schema: + """Returns a schema containing all categorical targets.""" + output = Schema() + for col in schema.select_by_tag([Tags.CATEGORICAL]): + if col.int_domain and col.int_domain.max > 1: + output += Schema([col]) + + return output + + +class CategoricalTarget(nn.Module): + """Prediction of a categorical feature. + + Parameters + -------------- + feature: Union[ColumnSchema, Schema], optional + Schema of the column being targeted. The schema must contain an + 'int_domain' specifying the maximum integer value representing the + categorical classes. + activation: callable, optional + Activation function to be applied to the output of the linear layer. + If None, no activation function is applied. + bias: bool, default=True + If set to False, the layer will not learn an additive bias. + + Returns + --------- + torch.Tensor + The tensor output of the forward method. + """ + + def __init__( + self, + feature: Optional[Union[Schema, ColumnSchema]] = None, + activation=None, + bias: bool = True, + ): + super().__init__() + + if isinstance(feature, Schema): + assert len(feature) == 1, "Schema can have max 1 feature" + col_schema = feature.first + else: + col_schema = feature + + self.target_name = col_schema.name + self.num_classes = col_schema.int_domain.max + 1 + self.output_schema = categorical_output_schema(col_schema, self.num_classes) + + self.linear = nn.LazyLinear(self.num_classes, bias=bias) + self.activation = activation + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Computes the forward pass of the module and applies the activation function if present. + + Parameters + -------------- + inputs: torch.Tensor + Input tensor for the forward pass. + + Returns + --------- + torch.Tensor + Output tensor from the forward pass of the model. + """ + output = self.linear(inputs) + if self.activation is not None: + output = self.activation(output) + + return output + + def embedding_lookup(self, ids: torch.Tensor) -> torch.Tensor: + """ + Selects the embeddings for the given indices. + + Parameters + -------------- + ids: torch.Tensor + Tensor containing indices for which embeddings are to be returned. + + Returns + --------- + torch.Tensor + The corresponding embeddings. + """ + return torch.index_select(self.embeddings(), 1, ids).t() + + def embeddings(self) -> nn.Parameter: + """ + Returns the embeddings from the weight matrix. + + Returns + --------- + nn.Parameter + The embeddings. + """ + return self.linear.weight.t() + + def should_apply_contrastive(self, batch: Optional[Batch]) -> bool: + if batch is not None and batch.targets and self.training: + return True + + return False + + +class EmbeddingTablePrediction(nn.Module): + """Prediction of a categorical feature using weight-sharing [1] with an embedding table. + + Parameters + ---------- + table : EmbeddingTable + The embedding table to use as the weight matrix. + + References: + ---------- + [1] Hakan Inan, Khashayar Khosravi, and Richard Socher. 2016. Tying word vectors + and word classifiers: A loss framework for language modeling. 
arXiv preprint
+    arXiv:1611.01462 (2016).
+    """
+
+    def __init__(self, table: EmbeddingTable, selection: Optional[schema.Selection] = None):
+        super().__init__()
+        self.table = table
+        if len(table.domains) > 1:
+            if not selection:
+                raise ValueError(
+                    f"Table {table} has multiple columns. "
+                    "Must specify selection to choose column."
+                )
+            self.add_selection(selection)
+        else:
+            self.num_classes = table.num_embeddings
+            self.col_schema = table.input_schema.first
+            self.col_name = self.col_schema.name
+            self.target_name = self.col_name
+        self.bias = nn.Parameter(
+            torch.zeros(self.num_classes, dtype=torch.float32, device=self.embeddings().device)
+        )
+        self.output_schema = categorical_output_schema(self.col_schema, self.num_classes)
+
+    @classmethod
+    def with_weight_tying(
+        cls,
+        block: nn.Module,
+        selection: Optional[schema.Selection] = None,
+    ) -> "EmbeddingTablePrediction":
+        if isinstance(block, EmbeddingTable):
+            table = block
+        else:
+            if not selection:
+                raise ValueError(
+                    "Must specify a `selection` when providing a block that isn't a table."
+                )
+
+            try:
+                selected = schema.select(block, selection)
+                table = selected.leaf()
+            except Exception as e:
+                raise ValueError("Could not find embedding table in block.") from e
+
+        return cls(table, selection)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the model using input tensor.
+
+        Parameters
+        ----------
+        inputs : torch.Tensor
+            Input tensor for the forward pass.
+
+        Returns
+        ----------
+        torch.Tensor
+            Output tensor of the forward pass.
+        """
+        return nn.functional.linear(inputs, self.embeddings(), self.bias)
+
+    def add_selection(self, selection: schema.Selection):
+        selected = schema.select(self.table.input_schema, selection)
+        if len(selected) != 1:
+            raise ValueError(f"Schema must contain exactly one column, got: {selected}")
+        self.col_schema = selected.first
+        self.col_name = self.col_schema.name
+        self.num_classes = self.col_schema.int_domain.max + 1
+        self.output_schema = categorical_output_schema(self.col_schema, self.num_classes)
+        self.target_name = self.col_name
+
+        return self
+
+    def embeddings(self) -> nn.Parameter:
+        """Fetch the weight matrix from the embedding table.
+
+        Returns
+        ----------
+        nn.Parameter
+            Weight matrix from the embedding table.
+        """
+        if len(self.table.domains) > 1:
+            return self.table.feature_weights(self.col_name)
+
+        return self.table.table.weight
+
+    def embedding_lookup(self, inputs: torch.Tensor) -> torch.Tensor:
+        """Fetch the embeddings for given indices from the embedding table.
+
+        Parameters
+        ----------
+        inputs : torch.Tensor
+            Tensor containing indices for which embeddings are to be returned.
+
+        Returns
+        ----------
+        torch.Tensor
+            The corresponding embeddings.
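+
+        Example usage (a sketch; assumes ``table`` is an ``EmbeddingTable``
+        wrapping a single categorical column)::
+
+            >>> prediction = EmbeddingTablePrediction(table)
+            >>> embeddings = prediction.embedding_lookup(torch.tensor([0, 1, 2]))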
+ """ + return self.table({self.col_name: inputs})[self.col_name] + + def should_apply_contrastive(self, batch: Optional[Batch]) -> bool: + if batch is not None and batch.targets and self.training: + return True + + return False + + +def categorical_output_schema(target: ColumnSchema, num_classes: int) -> Schema: + """Return the output schema given the target column schema.""" + _target = target.with_dtype(md.float32) + _target = _target.with_properties( + {"domain": {"min": 0, "max": 1, "name": _target.name}}, + ) + if "value_count" not in target.properties: + _target = _target.with_properties( + {"value_count": {"min": num_classes, "max": num_classes}}, + ) + + return Schema([_target]) + + +def create_retrieval_metrics( + metrics: Sequence[Type[tm.Metric]], ks: Sequence[int] +) -> List[tm.Metric]: + """ + Create a list of retrieval metrics given metric types and a list of integers. + For each integer in `ks`, a metric is created for each type in `metrics`. + + Parameters + ---------- + metrics : Sequence[Type[tm.Metric]] + The types of metrics to create. Each type should be a callable that + accepts a single integer parameter `k` to instantiate a new metric. + ks : Sequence[int] + A list of integers to use as the `k` or `top_k` parameter when creating each metric. + + Returns + ------- + List[tm.Metric] + A list of metrics. The length of the list is equal to the product of + the lengths of `metrics` and `ks`. The metrics are ordered first by + the values in `ks`, then by the order in `metrics`. + """ + + outputs = [] + + for k in ks: + for metric in metrics: + # check the parameters of the callable metric + params = inspect.signature(metric).parameters + + # the argument name could be 'k' or 'top_k' + arg_name = "top_k" if "top_k" in params else "k" if "k" in params else None + + if arg_name is not None: + outputs.append(metric(**{arg_name: k})) + else: + raise ValueError( + "Expected a callable that accepts either ", + f"a 'k' or a 'top_k' parameter, but got {metric}", + ) + + return outputs diff --git a/merlin/models/torch/outputs/contrastive.py b/merlin/models/torch/outputs/contrastive.py new file mode 100644 index 0000000000..733a3d69b4 --- /dev/null +++ b/merlin/models/torch/outputs/contrastive.py @@ -0,0 +1,476 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+from typing import Dict, Final, Optional, Sequence, Tuple, Union
+
+import torch
+import torchmetrics as tm
+from torch import nn
+
+from merlin.models.torch import schema
+from merlin.models.torch.batch import Batch
+from merlin.models.torch.block import Block
+from merlin.models.torch.outputs import sampling  # noqa: F401
+from merlin.models.torch.outputs.base import ModelOutput
+from merlin.models.torch.outputs.classification import (
+    CategoricalOutput,
+    CategoricalTarget,
+    EmbeddingTablePrediction,
+    create_retrieval_metrics,
+)
+from merlin.models.utils.constants import MIN_FLOAT
+from merlin.schema import ColumnSchema, Schema
+
+
+class ContrastiveOutput(ModelOutput):
+    """
+    A prediction block for a contrastive output.
+
+    Parameters
+    ----------
+    schema: Union[ColumnSchema, Schema], optional
+        The schema defining the column properties. Default is None.
+    loss : nn.Module, optional
+        The loss function to use for the output model, defaults to
+        torch.nn.CrossEntropyLoss.
+    metrics : Optional[Sequence[Metric]], optional
+        The metrics to evaluate the model output.
+    logits_temperature: float, optional
+        Temperature ``T`` used to rescale the logits as ``logits / T``,
+        intended to temper model confidence. By default 1.0.
+    """
+
+    def __init__(
+        self,
+        schema: Optional[Union[ColumnSchema, Schema]] = None,
+        negative_sampling="in-batch",
+        loss: nn.Module = nn.CrossEntropyLoss(),
+        metrics: Optional[Sequence[tm.Metric]] = None,
+        downscore_false_negatives: bool = True,
+        false_negative_score: float = MIN_FLOAT,
+        logits_temperature: float = 1.0,
+    ):
+        if not metrics:
+            metrics = create_retrieval_metrics(
+                CategoricalOutput.DEFAULT_METRICS_CLS, CategoricalOutput.DEFAULT_K
+            )
+
+        super().__init__(
+            loss=loss,
+            metrics=metrics,
+            logits_temperature=logits_temperature,
+        )
+
+        if schema:
+            self.initialize_from_schema(schema)
+
+        self.init_hook_handle = self.register_forward_pre_hook(self.initialize)
+        if not torch.jit.is_scripting():
+            if isinstance(negative_sampling, str):
+                negative_sampling = [negative_sampling]
+            self.negative_sampling = nn.ModuleList(Block.parse(s) for s in negative_sampling)
+        self.downscore_false_negatives = downscore_false_negatives
+        self.false_negative_score = false_negative_score
+
+    @classmethod
+    def with_weight_tying(
+        cls,
+        block: nn.Module,
+        selection: Optional[schema.Selection] = None,
+        negative_sampling="popularity",
+        loss: nn.Module = nn.CrossEntropyLoss(),
+        metrics: Optional[Sequence[tm.Metric]] = None,
+        logits_temperature: float = 1.0,
+        downscore_false_negatives: bool = True,
+        false_negative_score: float = MIN_FLOAT,
+    ) -> "ContrastiveOutput":
+        self = cls(
+            loss=loss,
+            metrics=metrics,
+            logits_temperature=logits_temperature,
+            negative_sampling=negative_sampling,
+            downscore_false_negatives=downscore_false_negatives,
+            false_negative_score=false_negative_score,
+        )
+        self = self.tie_weights(block, selection)
+
+        return self
+
+    def tie_weights(
+        self, block: nn.Module, selection: Optional[schema.Selection] = None
+    ) -> "ContrastiveOutput":
+        prediction = EmbeddingTablePrediction.with_weight_tying(block, selection)
+        self.num_classes = prediction.num_classes
+        if self:
+            self[0] = prediction
+        else:
+            self.prepend(prediction)
+        self.set_to_call(prediction)
+
+        return self
+
+    def initialize_from_schema(self, target: Union[ColumnSchema, Schema]):
+        """Set up the schema for the output.
+
+        Parameters
+        ----------
+        target: Union[ColumnSchema, Schema]
+            The schema defining the column properties.
+ """ + if not isinstance(target, (ColumnSchema, Schema)): + raise ValueError(f"Target must be a ColumnSchema or Schema, got {target}.") + + if isinstance(target, Schema) and len(target) > 1: + if len(target) > 2: + raise ValueError(f"Schema must have one or two column(s), got {target}.") + + self.set_to_call(DotProduct(*target.column_names)) + else: + if isinstance(target, Schema): + target = target.first + self.set_to_call(CategoricalTarget(target)) + + self.prepend(self.to_call) + + def initialize(self, module, inputs): + if torch.jit.isinstance(inputs[0], Dict[str, torch.Tensor]): + for i in range(len(module)): + if isinstance(module[i], CategoricalTarget): + module[i] = DotProduct(module[i].target_name) + self.set_to_call(module[i]) + elif isinstance(self.to_call, CategoricalTarget): + # Make sure CategoticalTarget is initialized + outputs = inputs[0] + for to_call in module.values: + outputs = to_call(outputs) + + self.init_hook_handle.remove() # Clear hook once block is initialized + + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ) -> torch.Tensor: + outputs = inputs + for module in self.values: + should_apply_contrastive = False + + can_contrast = ( + not torch.jit.is_scripting() + and hasattr(module, "requires_batch") + and hasattr(module.to_wrap, "should_apply_contrastive") + ) + if can_contrast: + should_apply_contrastive = module.to_wrap.should_apply_contrastive(batch) + + if should_apply_contrastive: + if batch is None: + raise ValueError("Contrastive output requires a batch") + _batch = batch if batch is not None else Batch({}) + + target_name = module.to_wrap.target_name + + if torch.jit.isinstance(outputs, Dict[str, torch.Tensor]) and hasattr( + module.to_wrap, "get_query_name" + ): + query_name: str = module.to_wrap.get_query_name(outputs) + outputs = self.contrastive_dot_product( + outputs, + _batch, + target_name=target_name, + query_name=query_name, + ) + elif torch.jit.isinstance(outputs, torch.Tensor): + outputs = self.contrastive_lookup( + outputs, + _batch, + target_name=target_name, + ) + else: + raise RuntimeError("Couldn't apply contrastive output") + else: + outputs = module(outputs, batch=batch) + + return outputs + + @torch.jit.unused + def contrastive_dot_product( + self, + inputs: Dict[str, torch.Tensor], + batch: Batch, + target_name: str, + query_name: str, + ) -> torch.Tensor: + query = inputs[query_name] + positive = inputs[target_name] + positive_id = None + if target_name in batch.features: + positive_id = batch.features[target_name] + + negative, negative_id = self.sample_negatives(positive, positive_id=positive_id) + + return self.contrastive_outputs( + query, + positive, + negative, + positive_id=positive_id, + negative_id=negative_id, + ) + + @torch.jit.unused + def contrastive_lookup( + self, + inputs: torch.Tensor, + batch: Batch, + target_name: str, + ) -> torch.Tensor: + query = inputs + if len(batch.targets) == 1: + positive_id = batch.target() + else: + positive_id = batch.targets[target_name] + + if not hasattr(self.to_call, "embedding_lookup"): + raise ValueError("Couldn't infer positive embedding") + positive = self.embedding_lookup(positive_id) + + negative, negative_id = self.sample_negatives(positive, positive_id=positive_id) + + return self.contrastive_outputs( + query, + positive, + negative, + positive_id=positive_id, + negative_id=negative_id, + ) + + @torch.jit.unused + def sample_negatives( + self, positive: torch.Tensor, positive_id: torch.Tensor + ) -> 
Tuple[torch.Tensor, torch.Tensor]: + """ + Samples negative examples for the given positive tensor. + + Args: + positive (torch.Tensor): Tensor containing positive samples. + positive_id (torch.Tensor, optional): Tensor containing the IDs + of positive samples. Defaults to None. + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing the negative samples + tensor and the IDs of negative samples. + """ + outputs, ids = [], [] + for sampler in self.negative_sampling: + _positive_id: torch.Tensor = positive_id if positive_id is not None else torch.tensor(1) + negative, negative_id = sampler(positive, positive_id=_positive_id) + if negative.shape[0] == negative_id.shape[0]: + ids.append(negative_id) + outputs.append(negative) + + if ids: + if len(outputs) != len(ids): + raise RuntimeError("The number of negative samples and ids must be the same") + + negative_tensor = torch.cat(outputs, dim=1) if len(outputs) > 1 else outputs[0] + id_tensor = torch.tensor(1) + if ids: + id_tensor = torch.cat(ids, dim=0) if len(ids) > 1 else ids[0] + + return negative_tensor, id_tensor + + @torch.jit.unused + def contrastive_outputs( + self, + query: torch.Tensor, + positive: torch.Tensor, + negative: torch.Tensor, + positive_id: Optional[torch.Tensor] = None, + negative_id: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Computes the contrastive outputs given + the query tensor, positive tensor, and negative tensor. + + Args: + query (torch.Tensor): Tensor containing the query data. + positive (torch.Tensor): Tensor containing the positive data. + negative (torch.Tensor): Tensor containing the negative data. + positive_id (torch.Tensor, optional): Tensor containing the IDs of positive samples. + Defaults to None. + negative_id (torch.Tensor, optional): Tensor containing the IDs of negative samples. + Defaults to None. + + Returns: + torch.Tensor: Tensor containing the contrastive outputs. + + Note, the transformed-targets are stored in + the `target` attribute (which is a buffer). 
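+
+        Shape sketch (illustrative): for a batch of size B with N sampled
+        negatives, ``positive_scores`` is [B, 1], ``negative_scores`` is [B, N],
+        and the returned tensor is [B, N + 1], with ``target`` set to 1 in
+        column 0 and 0 elsewhere.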
+ """ + + # Dot-product for the positive-scores + positive_scores = torch.sum(query * positive, dim=-1, keepdim=True) + negative_scores = torch.matmul(query, negative.t()) + + if self.downscore_false_negatives: + if ( + positive_id is None + or negative_id is None + or positive.shape[0] != positive_id.shape[-1] + or negative.shape[0] != negative_id.shape[-1] + ): + raise RuntimeError( + "Both positive_id and negative_id must be provided " + "when downscore_false_negatives is True" + ) + negative_scores, _ = rescore_false_negatives( + positive_id, negative_id, negative_scores, self.false_negative_score + ) + + if len(negative_scores.shape) + 1 == len(positive_scores.shape): + negative_scores = negative_scores.unsqueeze(0) + output = torch.cat([positive_scores, negative_scores], dim=-1) + + # Ensure the output is always float32 to avoid numerical instabilities + output = output.to(torch.float32) + + batch_size = output.shape[0] + num_negatives = output.shape[1] - 1 + + self.target = torch.cat( + [ + torch.ones(batch_size, 1, dtype=output.dtype), + torch.zeros(batch_size, num_negatives, dtype=output.dtype), + ], + dim=1, + ) + + return output + + def embedding_lookup(self, ids: torch.Tensor) -> torch.Tensor: + return self.to_call.embedding_lookup(torch.squeeze(ids)) + + def set_to_call(self, to_call: nn.Module): + self.to_call = to_call + if not torch.jit.is_scripting() and hasattr(self, "negative_sampling"): + for sampler in self.negative_sampling: + if hasattr(sampler, "set_to_call"): + sampler.set_to_call(to_call) + + +@Block.registry.register("dot-product") +class DotProduct(nn.Module): + """Dot-product between queries & candidates. + + Parameters: + ----------- + query_name : str, optional + Identify query tower for query/user embeddings, by default 'query' + candidate_name : str, optional + Identify item tower for item embeddings, by default 'candidate' + """ + + is_output_module: Final[bool] = True + + def __init__( + self, + candidate_name: str = "candidate", + query_name: Optional[str] = None, + ): + super().__init__() + self.query_name: Optional[str] = query_name + self.candidate_name = candidate_name + self.target_name = self.candidate_name + + def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + if len(inputs) < 2: + raise RuntimeError(f"DotProduct requires at least two inputs, got: {inputs}") + + candidate = inputs[self.candidate_name] + + if len(inputs) == 2: + query_name = self.get_query_name(inputs) + elif self.query_name is not None: + query_name = self.query_name + else: + raise RuntimeError( + "DotProduct requires query_name to be set when more than ", + f"two inputs are provided, got: {inputs}", + ) + query = inputs[query_name] + + # Alternative is: torch.einsum('...i,...i->...', query, item) + return torch.sum(query * candidate, dim=-1, keepdim=True) + + def get_query_name(self, inputs: Dict[str, torch.Tensor]) -> str: + if self.query_name is None: + for key in inputs: + if key != self.candidate_name: + return key + else: + return self.query_name + + raise RuntimeError( + "DotProduct requires query_name to be set when more than two inputs are provided" + ) + + def should_apply_contrastive(self, batch: Optional[Batch]) -> bool: + return self.training + + +def rescore_false_negatives( + positive_item_ids: torch.Tensor, + neg_samples_item_ids: torch.Tensor, + negative_scores: torch.Tensor, + false_negatives_score: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Zeroes the logits of accidental negatives. 
+ + Parameters + ---------- + positive_item_ids : torch.Tensor + A tensor containing the IDs of the positive items. + neg_samples_item_ids : torch.Tensor + A tensor containing the IDs of the negative samples. + negative_scores : torch.Tensor + A tensor containing the scores of the negative samples. + false_negatives_score : float + The score to assign to false negatives (accidental hits). + + Returns + ------- + torch.Tensor + A tensor containing the rescored negative scores. + torch.Tensor + A tensor containing a mask representing valid negatives. + """ + + # Removing dimensions of size 1 from the shape of the item ids, if applicable + positive_item_ids = torch.squeeze(positive_item_ids).to(neg_samples_item_ids.dtype) + neg_samples_item_ids = torch.squeeze(neg_samples_item_ids) + + # Reshapes positive and negative ids so that false_negatives_mask matches the scores shape + false_negatives_mask = torch.eq( + positive_item_ids.unsqueeze(-1), neg_samples_item_ids.unsqueeze(0) + ) + + # Setting a very small value for false negatives (accidental hits) so that it has + # negligible effect on the loss functions + negative_scores = torch.where( + false_negatives_mask, + torch.ones_like(negative_scores) * false_negatives_score, + negative_scores, + ) + + valid_negatives_mask = ~false_negatives_mask + + return torch.squeeze(negative_scores), valid_negatives_mask diff --git a/merlin/models/torch/outputs/regression.py b/merlin/models/torch/outputs/regression.py index e3b2f97b09..341b0cecc6 100644 --- a/merlin/models/torch/outputs/regression.py +++ b/merlin/models/torch/outputs/regression.py @@ -48,12 +48,17 @@ def __init__( """Initializes a RegressionOutput object.""" super().__init__( nn.LazyLinear(1), - schema=schema, loss=loss or self.DEFAULT_LOSS_CLS(), metrics=metrics or [m() for m in self.DEFAULT_METRICS_CLS], ) + if schema: + self.initialize_from_schema(schema) + self._initialized_from_schema = True - def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]): + if not self.metrics: + self.metrics = self.default_metrics() + + def initialize_from_schema(self, target: Optional[Union[ColumnSchema, Schema]]): """Set up the schema for the output. Parameters @@ -62,7 +67,7 @@ def setup_schema(self, target: Optional[Union[ColumnSchema, Schema]]): The schema defining the column properties. 
""" if isinstance(target, Schema): - if len(target) != 1: + if len(target) > 1: raise ValueError("Schema must contain exactly one column.") target = target.first diff --git a/merlin/models/torch/outputs/sampling/__init__.py b/merlin/models/torch/outputs/sampling/__init__.py new file mode 100644 index 0000000000..535d01fdf5 --- /dev/null +++ b/merlin/models/torch/outputs/sampling/__init__.py @@ -0,0 +1,7 @@ +from merlin.models.torch.outputs.sampling.in_batch import InBatchNegativeSampler +from merlin.models.torch.outputs.sampling.popularity import PopularityBasedSampler + +__all__ = [ + "InBatchNegativeSampler", + "PopularityBasedSampler", +] diff --git a/merlin/models/torch/outputs/sampling/in_batch.py b/merlin/models/torch/outputs/sampling/in_batch.py new file mode 100644 index 0000000000..2014606a23 --- /dev/null +++ b/merlin/models/torch/outputs/sampling/in_batch.py @@ -0,0 +1,40 @@ +from typing import Tuple + +import torch +from torch import nn + +from merlin.models.torch.block import registry + + +@registry.register("in-batch") +class InBatchNegativeSampler(nn.Module): + """PyTorch module that performs in-batch negative sampling.""" + + def __init__(self): + super().__init__() + self.register_buffer("negative", torch.zeros(1, dtype=torch.float32)) + self.register_buffer("negative_id", torch.zeros(1)) + + def forward( + self, positive: torch.Tensor, positive_id: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Doing in-batch negative-sampling. + + positive & positive_id are registered as non-persistent buffers + + Args: + positive (torch.Tensor): Tensor containing positive samples. + positive_id (torch.Tensor, optional): Tensor containing the IDs of + positive samples. Defaults to None. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the positive + samples tensor and the positive samples IDs tensor. + """ + if self.training: + if torch.jit.isinstance(positive, torch.Tensor): + self.negative = positive + if torch.jit.isinstance(positive_id, torch.Tensor): + self.negative_id = positive_id + + return positive, positive_id diff --git a/merlin/models/torch/outputs/sampling/popularity.py b/merlin/models/torch/outputs/sampling/popularity.py new file mode 100644 index 0000000000..38e90c51aa --- /dev/null +++ b/merlin/models/torch/outputs/sampling/popularity.py @@ -0,0 +1,254 @@ +from typing import Optional + +import torch +from torch import nn + +from merlin.models.torch.block import registry + + +class LogUniformSampler(torch.nn.Module): + """ + LogUniformSampler samples negative samples based on a log-uniform distribution. + `P(class) = (log(class + 2) - log(class + 1)) / log(max_id + 1)` + + This implementation is based on to: + https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/utils/log_uniform_sampler.py + TensorFlow Reference: + https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py + + LogUniformSampler assumes item ids are sorted decreasingly by their frequency. + + if `unique_sampling==True`, then only unique sampled items will be returned. + The actual # samples will vary from run to run if `unique_sampling==True`, + as sampling without replacement (`torch.multinomial(..., replacement=False)`) is slow, + so we use `torch.multinomial(..., replacement=True).unique()` which doesn't guarantee + the same number of unique sampled items. You can try to increase + n_samples_multiplier_before_unique to increase the chances to have more + unique samples in that case. 
+
+    Parameters
+    ----------
+    max_n_samples : int
+        The maximum desired number of negative samples. The number of samples might be
+        smaller than that if `unique_sampling==True`, as explained above.
+    max_id : int
+        The maximum value of the range for the log-uniform distribution.
+    min_id : Optional[int]
+        The minimum value of the range for the log-uniform sampling. By default 0.
+    unique_sampling : bool
+        Whether to return unique samples. By default True
+    n_samples_multiplier_before_unique : int
+        If unique_sampling=True, it is not guaranteed that the number of returned
+        samples will be equal to max_n_samples, as explained above.
+        You can increase n_samples_multiplier_before_unique to maximize
+        chances that a larger number of unique samples is returned.
+    """
+
+    def __init__(
+        self,
+        max_n_samples: int,
+        max_id: int,
+        min_id: Optional[int] = 0,
+        unique_sampling: bool = True,
+        n_samples_multiplier_before_unique: int = 2,
+    ):
+        super().__init__()
+
+        if max_id <= 0:
+            raise ValueError("max_id must be a positive integer.")
+        if max_n_samples <= 0:
+            raise ValueError("max_n_samples must be a positive integer.")
+
+        self.max_id = max_id
+        self.unique_sampling = unique_sampling
+        self.max_n_samples = max_n_samples
+        self.n_sample = max_n_samples
+        if self.unique_sampling:
+            self.n_sample = int(self.n_sample * n_samples_multiplier_before_unique)
+
+        with torch.no_grad():
+            dist = self.get_log_uniform_distr(max_id, min_id)
+            self.register_buffer("dist", dist)
+            unique_sampling_dist = self.get_unique_sampling_distr(dist, self.n_sample)
+            self.register_buffer("unique_sampling_dist", unique_sampling_dist)
+
+    def get_log_uniform_distr(self, max_id: int, min_id: int = 0) -> torch.Tensor:
+        """Approximates the items' frequency distribution with a log-uniform probability
+        distribution with P(class) = (log(class + 2) - log(class + 1)) / log(max_id + 1).
+        It assumes item ids are sorted decreasingly by their frequency.
+
+        Parameters
+        ----------
+        max_id : int
+            Maximum discrete value for sampling (e.g. cardinality of the item id)
+
+        Returns
+        -------
+        torch.Tensor
+            Returns the log uniform probability distribution
+        """
+        log_indices = torch.arange(1.0, max_id - min_id + 2.0, 1.0).log_()
+        probs = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
+        if min_id > 0:
+            probs = torch.cat(
+                [torch.zeros([min_id], dtype=probs.dtype), probs], axis=0
+            )  # type: ignore
+        return probs
+
+    def get_unique_sampling_distr(self, dist, n_sample):
+        """Returns the probability that each item is sampled at least once
+        given the specified number of trials. This is meant to be used when
+        self.unique_sampling == True.
+        That probability can be approximated by 1 - (1 - p)^n
+        and we use a numerically stable version: -expm1(num_tries * log1p(-p))
+        """
+        return (-(-dist.double().log1p_() * n_sample).expm1_()).float()
+
+    @torch.jit.unused
+    def sample(self, labels: torch.Tensor):
+        """Sample negative samples and calculate their probabilities.
+
+        If `unique_sampling==True`, then only unique sampled items will be returned.
+        The actual # samples will vary from run to run if `unique_sampling==True`,
+        as sampling without replacement (`torch.multinomial(..., replacement=False)`) is slow,
+        so we use `torch.multinomial(..., replacement=True).unique()`
+        which doesn't guarantee the same number of unique sampled items.
+        You can try to increase n_samples_multiplier_before_unique
+        to increase the chances to have more unique samples in that case.
+
+        Parameters
+        ----------
+        labels : torch.Tensor, dtype=torch.long, shape=(batch_size,)
+            The input labels for which negative samples should be generated.
+
+        Returns
+        -------
+        neg_samples : torch.Tensor, dtype=torch.long, shape=(n_samples,)
+            The unique negative samples drawn from the log-uniform distribution.
+        true_probs : torch.Tensor, dtype=torch.float32, shape=(batch_size,)
+            The probabilities of the input labels according
+            to the log-uniform distribution (depends on self.unique_sampling choice).
+        samp_log_probs : torch.Tensor, dtype=torch.float32, shape=(n_samples,)
+            The probabilities of the sampled negatives according
+            to the log-uniform distribution (depends on self.unique_sampling choice).
+        """
+
+        if not torch.is_tensor(labels):
+            raise TypeError("Labels must be a torch.Tensor.")
+        if labels.dtype != torch.long:
+            raise ValueError("Labels must be a tensor of dtype long.")
+        if labels.dim() > 2 or (labels.dim() == 2 and min(labels.shape) > 1):
+            raise ValueError(
+                "Labels must be a 1-dimensional tensor or a 2-dimensional tensor "
+                "with one of the dimensions equal to 1."
+            )
+        if labels.size(0) == 0:
+            raise ValueError("Labels must not be an empty tensor.")
+        if (labels < 0).any() or (labels > self.max_id).any():
+            raise ValueError(
+                "All label values must be within the range [0, max_id], "
+                f"got: [{labels.min().item()}, {labels.max().item()}]."
+            )
+
+        n_tries = self.n_sample
+
+        with torch.no_grad():
+            neg_samples = torch.multinomial(
+                self.dist, n_tries, replacement=True  # type: ignore
+            ).unique()[: self.max_n_samples]
+
+        device = labels.device
+        neg_samples = neg_samples.to(device)
+
+        if self.unique_sampling:
+            dist = self.unique_sampling_dist
+        else:
+            dist = self.dist
+
+        true_probs = dist[labels]  # type: ignore
+        samples_probs = dist[neg_samples]  # type: ignore
+
+        return neg_samples, true_probs, samples_probs
+
+
+@registry.register("popularity")
+class PopularityBasedSampler(nn.Module):
+    """The PopularityBasedSampler generates negative samples for a positive
+    input sample based on popularity.
+
+    The class utilizes a LogUniformSampler to draw negative samples from a
+    log-uniform distribution. The sampler approximates the items' frequency
+    distribution with a log-uniform probability distribution. The sampler assumes
+    item ids are sorted decreasingly by their frequency.
+
+    Parameters
+    ----------
+    max_num_samples : int, optional
+        The maximum number of negative samples desired. Default is 10.
+    unique_sampling : bool, optional
+        Whether to return unique samples. Default is True.
+    n_samples_multiplier_before_unique : int, optional
+        Factor to increase the chances to have more unique samples. Default is 2.
+
+    """
+
+    def __init__(
+        self,
+        max_num_samples: int = 10,
+        unique_sampling: bool = True,
+        n_samples_multiplier_before_unique: int = 2,
+    ):
+        super().__init__()
+        self.labels = torch.ones((1, 1), dtype=torch.int64)
+        self.max_num_samples = max_num_samples
+        self.register_buffer("negative", None)
+        self.register_buffer("negative_id", None)
+        self.unique_sampling = unique_sampling
+        self.n_samples_multiplier_before_unique = n_samples_multiplier_before_unique
+
+    def forward(self, positive, positive_id=None):
+        """Computes the forward pass of the PopularityBasedSampler.
+
+        Parameters
+        ----------
+        positive : torch.Tensor
+            The positive samples.
+        positive_id : torch.Tensor, optional
+            The ids of the positive samples. Default is None.
+
+        Returns
+        -------
+        negative : torch.Tensor
+            The negative samples.
+ negative_id : torch.Tensor + The ids of the negative samples. + """ + + if torch.jit.is_scripting(): + raise RuntimeError("PopularityBasedSampler is not supported in TorchScript.") + + if not hasattr(self, "to_call"): + raise RuntimeError("PopularityBasedSampler must be called after the model") + + del positive, positive_id + self.negative_id, _, _ = self.sampler.sample(self.labels) + + self.negative = self.to_call.embedding_lookup(torch.squeeze(self.negative_id)) + + return self.negative, self.negative_id + + def set_to_call(self, to_call: nn.Module): + """Set the model that will utilize this sampler. + + Parameters + ---------- + to_call : torch.nn.Module + The model that will utilize this sampler. + """ + self.to_call = to_call + self.sampler = LogUniformSampler( + max_n_samples=self.max_num_samples, + max_id=self.to_call.num_classes, + unique_sampling=self.unique_sampling, + n_samples_multiplier_before_unique=self.n_samples_multiplier_before_unique, + ) diff --git a/merlin/models/torch/outputs/tabular.py b/merlin/models/torch/outputs/tabular.py index 61df85fe00..8510e057bf 100644 --- a/merlin/models/torch/outputs/tabular.py +++ b/merlin/models/torch/outputs/tabular.py @@ -16,7 +16,7 @@ from typing import Any, Callable, Optional, Union -from merlin.models.torch.outputs.classification import BinaryOutput +from merlin.models.torch.outputs.classification import BinaryOutput, CategoricalOutput from merlin.models.torch.outputs.regression import RegressionOutput from merlin.models.torch.router import RouterBlock from merlin.models.torch.schema import Selection, select @@ -49,22 +49,26 @@ class TabularOutputBlock(RouterBlock): def __init__( self, - schema: Schema, + schema: Optional[Schema] = None, init: Optional[Union[str, Initializer]] = None, selection: Optional[Selection] = Tags.TARGET, ): - if selection: - schema = select(schema, selection) - + self.selection = selection + self.init = init super().__init__(schema, prepend_routing_module=False) + + def initialize_from_schema(self, schema: Schema): + if self.selection: + schema = select(schema, self.selection) + super().initialize_from_schema(schema) self.schema: Schema = self.selectable.schema - if init: - if isinstance(init, str): - init = self.initializers.get(init) - if not init: - raise ValueError(f"Initializer {init} not found.") + if self.init: + if isinstance(self.init, str): + self.init = self.initializers.get(self.init) + if not self.init: + raise ValueError(f"Initializer {self.init} not found.") - init(self) + self.init(self) @classmethod def register_init(cls, name: str): @@ -96,9 +100,13 @@ def defaults(block: TabularOutputBlock): This function adds a route for each of the following tags: - Tags.CONTINUOUS/Tags.REGRESSION -> RegressionOutput - Tags.BINARY_CLASSIFICATION/Tags.BINARY -> BinaryOutput + - Tags.MULTI_CLASS_CLASSIFICATION/Tags.CATEGORICAL -> CategoricalOutput Args: block (TabularOutputBlock): The block to initialize. 
""" - block.add_route_for_each([Tags.CONTINUOUS, Tags.REGRESSION], RegressionOutput()) - block.add_route_for_each([Tags.BINARY_CLASSIFICATION, Tags.BINARY], BinaryOutput()) + block.add_route_for_each([Tags.CONTINUOUS, Tags.REGRESSION], RegressionOutput(), required=False) + block.add_route_for_each(BinaryOutput.schema_selection, BinaryOutput(), required=False) + block.add_route_for_each( + CategoricalOutput.schema_selection, CategoricalOutput(), required=False + ) diff --git a/merlin/models/torch/predict.py b/merlin/models/torch/predict.py index a921dc667c..2c78c6fa67 100644 --- a/merlin/models/torch/predict.py +++ b/merlin/models/torch/predict.py @@ -7,6 +7,8 @@ from merlin.core.dispatch import DataFrameLike, concat, concat_columns from merlin.dataloader.torch import Loader from merlin.io import Dataset +from merlin.models.torch.batch import Batch +from merlin.models.torch.block import BatchBlock, Block from merlin.models.torch.schema import Selection, select from merlin.schema import ColumnSchema, Schema from merlin.table import TensorTable @@ -15,7 +17,125 @@ DFType = TypeVar("DFType", bound=DataFrameLike) -class Encoder: +class EncoderBlock(Block): + """ + A block that runs a `BatchBlock` as a pre-processing step before running the rest. + + This ensures that the batch is created at inference time as well. + + Parameters + ---------- + *module : nn.Module + Variable length argument list of PyTorch modules. + pre : BatchBlock, optional + An instance of BatchBlock class for pre-processing. + If None, an instance of BatchBlock is created. + name : str, optional + A name for the encoder block. + + Raises + ------ + ValueError + If the 'pre' argument is not an instance of BatchBlock. + """ + + def __init__( + self, *module: nn.Module, pre: Optional[BatchBlock] = None, name: Optional[str] = None + ): + super().__init__(*module, name=name) + if isinstance(pre, BatchBlock): + self.pre = pre + elif pre is None: + self.pre = BatchBlock() + else: + raise ValueError(f"Invalid pre: {pre}, must be a BatchBlock") + + def forward( + self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None + ): + _batch: Batch = self.pre(inputs, batch=batch) + + outputs = _batch.inputs() + for block in self.values: + outputs = block(outputs, batch=_batch) + return outputs + + @torch.jit.unused + def encode( + self, + data: Union[Dataset, Loader, Batch], + selection: Optional[Selection] = None, + batch_size=None, + index: Optional[Selection] = None, + unique: bool = True, + ): + """ + Encodes a given data set. + + Parameters + ---------- + data : Union[Dataset, Loader, Batch] + Input data to encode. + selection : Selection, optional + Features to encode. If not provided, all features will be encoded. + batch_size : int, optional + Size of the batch to encode. + index : Selection, optional + Index selection. + unique : bool, optional + If True, only unique entries are returned. + + Returns + ------- + encoded_data : Dataset + Encoded data set. + """ + _dask_encoder = DaskEncoder(self, selection=selection) + + return _dask_encoder(data, batch_size=batch_size, index=index, unique=unique) + + @torch.jit.unused + def predict( + self, + data: Union[Dataset, Loader, Batch], + selection: Optional[Selection] = None, + batch_size=None, + index: Optional[Selection] = None, + prediction_suffix: str = "_prediction", + unique: bool = True, + ): + """ + Encodes a given data set and predicts the output. + All input-features will be present in the output. 
+ + Parameters + ---------- + data : Union[Dataset, Loader, Batch] + Input data to encode. + selection : Selection, optional + Features to encode. If not provided, all features will be encoded. + batch_size : int, optional + Size of the batch to encode. + index : Selection, optional + Index selection. + prediction_suffix : str, optional + The suffix to add to the prediction columns in the output DataFrame. + unique : bool, optional + If True, only unique entries are returned. + + Returns + ------- + predictions : dask_cudf.DataFrame + Predictions of the data set. + """ + _dask_predictor = DaskPredictor( + self, prediction_suffix=prediction_suffix, selection=selection + ) + + return _dask_predictor(data, batch_size=batch_size, index=index, unique=unique) + + +class DaskEncoder: """Encode various forms of data using a specified PyTorch module. Supporting multiple data formats like Datasets, Loaders, DataFrames, @@ -28,7 +148,7 @@ class Encoder: # `selection=Tags.USER` ensures that only the sub-module(s) of the model # that processes features tagged as user is used during encoding. # Additionally, it filters out all other features that aren't tagged as user. - >>> user_encoder = Encoder(model[0], selection=Tags.USER) + >>> user_encoder = DaskEncoder(model[0], selection=Tags.USER) # The index is used in the resulting DataFrame after encoding # Setting unique=True (default value) ensures that any duplicate rows @@ -316,7 +436,7 @@ def reduce(self, left: DFType, right: DFType): return concat([left, right]) -class Predictor(Encoder): +class DaskPredictor(DaskEncoder): """Prediction on various forms of data using a specified PyTorch module. This is especially useful when you want to keep track of both the original data and @@ -326,7 +446,7 @@ class Predictor(Encoder): Example usage:: >>> dataset = Dataset(...) >>> model = mm.TwoTowerModel(dataset.schema) - >>> predictor = Predictor(model) + >>> predictor = DaskPredictor(model) >>> predictions = predictor(dataset, batch_size=128) >>> print(predictions.compute()) user_id user_age item_id item_category click click_prediction diff --git a/merlin/models/torch/router.py b/merlin/models/torch/router.py index 326064c68c..33aa68cc75 100644 --- a/merlin/models/torch/router.py +++ b/merlin/models/torch/router.py @@ -15,6 +15,7 @@ # from copy import deepcopy +from inspect import isclass from typing import Optional from torch import nn @@ -48,19 +49,23 @@ class RouterBlock(ParallelBlock): def __init__(self, selectable: schema.Selectable, prepend_routing_module: bool = True): super().__init__() + self.prepend_routing_module = prepend_routing_module if isinstance(selectable, Schema): - from merlin.models.torch.inputs.select import SelectKeys + self.initialize_from_schema(selectable) + else: + self.selectable: schema.Selectable = selectable - selectable = SelectKeys(selectable) + def initialize_from_schema(self, schema): + from merlin.models.torch.inputs.select import SelectKeys - self.selectable: schema.Selectable = selectable - self.prepend_routing_module = prepend_routing_module + self.selectable = SelectKeys(schema) def add_route( self, selection: schema.Selection, module: Optional[nn.Module] = None, name: Optional[str] = None, + required: bool = True, ) -> "RouterBlock": """Add a new routing path for a given selection. @@ -80,6 +85,8 @@ def add_route( The module to append to the branch after selection. name : str, optional The name of the branch. Default is the name of the selection. + required : bool, optional + Whether the route is required. 
Default is True. Returns ------- @@ -87,12 +94,18 @@ def add_route( The router block with the new route added. """ + if self.selectable is None: + raise ValueError(f"{self} has nothing to select from, so cannot add route.") + routing_module = schema.select(self.selectable, selection) if not routing_module: + if required: + raise ValueError(f"Selection {selection} not found in {self.selectable}") + return self if module is not None: - schema.setup_schema(module, routing_module.schema) + schema.initialize_from_schema(module, routing_module.schema) if self.prepend_routing_module: if isinstance(module, ParallelBlock): @@ -117,7 +130,11 @@ def add_route( return self def add_route_for_each( - self, selection: schema.Selection, module: nn.Module, shared=False + self, + selection: schema.Selection, + module: nn.Module, + shared=False, + required: bool = True, ) -> "RouterBlock": """Add a new route for each column in a selection. @@ -152,12 +169,14 @@ def add_route_for_each( if shared: col_module = module else: - if hasattr(module, "copy"): + if isclass(module): + col_module = module(col) + elif hasattr(module, "copy"): col_module = module.copy() else: col_module = deepcopy(module) - self.add_route(col, col_module, name=col.name) + self.add_route(col, col_module, name=col.name, required=required) return self diff --git a/merlin/models/torch/schema.py b/merlin/models/torch/schema.py index 937a4a869b..f8a1af2531 100644 --- a/merlin/models/torch/schema.py +++ b/merlin/models/torch/schema.py @@ -21,6 +21,8 @@ from torch import nn from merlin.dispatch.lazy import LazyDispatcher +from merlin.models.torch.batch import Batch +from merlin.models.torch.utils.module_utils import check_batch_arg from merlin.schema import ColumnSchema, Schema, Tags, TagSet Selection = Union[Schema, ColumnSchema, Callable[[Schema], Schema], Tags, TagSet, List[Tags]] @@ -28,6 +30,15 @@ NAMESPACE_TAGS = [Tags.CONTEXT, Tags.USER, Tags.ITEM, Tags.SESSION] +class LazySchemaModuleMixin: + _initialized_from_schema = False + + def initialize_from_schema(self, schema): + if self._initialized_from_schema: + raise RuntimeError("Already initialized this module from a schema") + self._initialized_from_schema = True + + def default_tag_propagation(inputs: Schema, outputs: Schema): if inputs: namespaces = [] @@ -68,6 +79,7 @@ def tensors(self, inputs): def get_schema(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema]) -> Schema: if isinstance(inputs, Schema): return inputs + return self.tensors(inputs) @@ -97,7 +109,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema return super().__call__(module, inputs) except NotImplementedError: raise ValueError( - f"Could not get output schema of {module} " "please call mm.trace_schema first." + f"Could not get output schema of {module} " "please call `mm.schema.trace` first." ) def trace( @@ -127,7 +139,7 @@ def _func(module: nn.Module, input: Schema) -> Schema: def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema: try: - _inputs = input(module) + _inputs = input_schema(module) inputs = _inputs except ValueError: pass @@ -156,7 +168,7 @@ def __call__(self, module: nn.Module, inputs: Optional[Schema] = None) -> Schema return super().__call__(module, inputs) except NotImplementedError: raise ValueError( - f"Could not get output schema of {module} " "please call mm.trace_schema first." + f"Could not get output schema of {module} " "please call `mm.schema.trace` first." 
) def trace( @@ -165,7 +177,7 @@ def trace( inputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema], outputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema], ) -> Schema: - _input_schema = input.get_schema(inputs) + _input_schema = input_schema.get_schema(inputs) _output_schema = self.get_schema(outputs) try: @@ -207,8 +219,8 @@ def extract(self, module: nn.Module, selection: Selection, route: nn.Module, nam return fn(module, selection, route, name=name) -input = _InputSchemaDispatch("input_schema") -output = _OutputSchemaDispatch("output_schema") +input_schema = _InputSchemaDispatch("input_schema") +output_schema = _OutputSchemaDispatch("output_schema") select = _SelectDispatch("selection") extract = _ExtractDispatch("extract") @@ -234,19 +246,40 @@ def trace(module: nn.Module, inputs: Union[torch.Tensor, Dict[str, torch.Tensor] """ hooks = [] + batch = kwargs.get("batch") def _hook(mod: nn.Module, inputs: Tuple[torch.Tensor], outputs: torch.Tensor): if not hasattr(mod, "__input_schemas"): mod.__input_schemas = () mod.__output_schemas = () + mod.__feature_schema = None + + _input_schema = input_schema.trace(mod, inputs[0]) + initialize_from_schema(mod, _input_schema) - _input_schema = input.trace(mod, inputs[0]) if _input_schema not in mod.__input_schemas: mod.__input_schemas += (_input_schema,) - mod.__output_schemas += (output.trace(mod, _input_schema, outputs),) + mod.__output_schemas += (output_schema.trace(mod, _input_schema, outputs),) + + if not isinstance(mod, torch.jit.TracedModule): + accepts_batch, requires_batch = check_batch_arg(mod) + + if requires_batch and batch is None: + raise ValueError( + f"Found module f{mod} that requires a batch. " + "`trace` was called without providing a batch. " + "Please provide a batch argument to `trace`. " + ) + if accepts_batch and hasattr(mod, "compute_feature_schema"): + feature_schema = input_schema.tensors(batch.features) + required_feature_schema = mod.compute_feature_schema(feature_schema) + mod.__feature_schema = required_feature_schema + elif requires_batch: + required_feature_schema = input_schema.tensors(batch.features) + mod.__feature_schema = required_feature_schema def add_hook(m): - custom_modules = list(output.dispatcher.registry.keys()) + custom_modules = list(output_schema.dispatcher.registry.keys()) if m and isinstance(m, tuple(custom_modules[1:])): return @@ -261,7 +294,7 @@ def add_hook(m): return module_out -def features(module: nn.Module) -> Schema: +def feature_schema(module: nn.Module) -> Schema: """Extract the feature schema from a PyTorch Module. This function operates by applying the `get_feature_schema` method @@ -285,15 +318,15 @@ def features(module: nn.Module) -> Schema: def get_feature_schema(module): nonlocal feature_schema - if hasattr(module, "feature_schema"): - feature_schema += module.feature_schema + if hasattr(module, "__feature_schema"): + feature_schema += module.__feature_schema module.apply(get_feature_schema) return feature_schema -def targets(module: nn.Module) -> Schema: +def target_schema(module: nn.Module) -> Schema: """ Extract the target schema from a PyTorch Module. @@ -326,7 +359,7 @@ def get_target_schema(module): return target_schema -def setup_schema(module: nn.Module, schema: Schema): +def initialize_from_schema(module: nn.Module, schema: Schema): """ Set up a schema for a given module. 
@@ -340,15 +373,18 @@ def setup_schema(module: nn.Module, schema: Schema): from merlin.models.torch.block import BlockContainer, ParallelBlock - if hasattr(module, "setup_schema"): - module.setup_schema(schema) + if hasattr(module, "initialize_from_schema") and not getattr( + module, "_initialized_from_schema", False + ): + module.initialize_from_schema(schema) + module._initialized_from_schema = True elif isinstance(module, ParallelBlock): for branch in module.branches.values(): - setup_schema(branch, schema) + initialize_from_schema(branch, schema) elif isinstance(module, BlockContainer) and module: - setup_schema(module[0], schema) + initialize_from_schema(module[0], schema) @select.register(Schema) @@ -388,6 +424,9 @@ def select_schema(schema: Schema, selection: Selection) -> Schema: elif isinstance(selection, (Tags, TagSet)): selected = schema.select_by_tag(selection) elif isinstance(selection, str): + if selection == "*": + return schema + selected = Schema([schema[selection]]) elif isinstance(selection, (list, tuple)): if all(isinstance(s, str) for s in selection): @@ -404,6 +443,35 @@ def select_schema(schema: Schema, selection: Selection) -> Schema: return selected +def select_union(*selections: Selection) -> Selection: + """ + Combine selections into a single selection. + + This function returns a new function `combined_select` that, when called, + will perform the union operation on all the input selections. + + Parameters + ---------- + *selections : Selection + Variable length argument list of Selection instances. + + Returns + ------- + Selection + A function that takes a Schema as input and returns a Schema which + is the union of all selections. + """ + + def combined_select(schema: Schema) -> Schema: + output = Schema() + for s in selections: + output += select(schema, s) + + return output + + return combined_select + + def selection_name(selection: Selection) -> str: """ Get the name of the selection. @@ -441,7 +509,7 @@ class Selectable: A mixin to allow to be selectable by schema. """ - def setup_schema(self, schema: Schema): + def initialize_from_schema(self, schema: Schema): """ Setup the schema for this selectable. 
@@ -484,8 +552,11 @@ def select(self, selection: Selection) -> "Selectable": raise NotImplementedError() -@output.register_tensor(torch.Tensor) +@output_schema.register_tensor(torch.Tensor) def _tensor_to_schema(input, name="output"): + if input is None: + return Schema([ColumnSchema(name)]) + kwargs = dict(dims=input.shape[1:], dtype=input.dtype) if len(input.shape) > 1 and input.dtype != torch.int32: @@ -494,29 +565,149 @@ def _tensor_to_schema(input, name="output"): return Schema([ColumnSchema(name, **kwargs)]) -@input.register_tensor(torch.Tensor) +@input_schema.register_tensor(torch.Tensor) def _(input): return _tensor_to_schema(input, "input") -@input.register_tensor(Dict[str, torch.Tensor]) -@output.register_tensor(Dict[str, torch.Tensor]) +@input_schema.register_tensor(Dict[str, torch.Tensor]) +@output_schema.register_tensor(Dict[str, torch.Tensor]) def _(input): output = Schema() for k, v in sorted(input.items()): - output += _tensor_to_schema(v, k) + if k.endswith("__offsets"): + name = k[: -len("__offsets")] + kwargs = dict(dtype=input[f"{name}__values"].dtype) + output += Schema([ColumnSchema(name, **kwargs)]) + elif k.endswith("__values"): + continue + else: + output += _tensor_to_schema(v, k) return output -@input.register_tensor(Tuple[torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor]) -@input.register_tensor(Tuple[torch.Tensor, torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor, torch.Tensor]) -@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -@input.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -@output.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +@input_schema.register_tensor(Tuple[torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor]) +@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor, Optional[torch.Tensor]]) +@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) +@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) +@input_schema.register_tensor( + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +) +@output_schema.register_tensor( + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +) +@input_schema.register_tensor( + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +) +@output_schema.register_tensor( + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +) +@input_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +) +@output_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +) +@input_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +) +@output_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, 
+ torch.Tensor, + torch.Tensor, + ] +) +@input_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +) +@output_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +) +@input_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +) +@output_schema.register_tensor( + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ] +) def _(input): output = Schema() @@ -524,3 +715,12 @@ def _(input): output += _tensor_to_schema(v, str(i)) return output + + +@output_schema.register_tensor(Batch) +def _(input): + schema = Schema() + schema += output_schema.tensors(input.features) + schema += output_schema.tensors(input.targets) + + return schema diff --git a/merlin/models/torch/transforms/agg.py b/merlin/models/torch/transforms/agg.py index 552fcf1d32..a340b87ce0 100644 --- a/merlin/models/torch/transforms/agg.py +++ b/merlin/models/torch/transforms/agg.py @@ -110,7 +110,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: _sorted_tensors = [] for tensor in sorted_tensors: if tensor.dim() < max_dims: - _sorted_tensors.append(tensor.unsqueeze(1)) + _sorted_tensors.append(tensor.unsqueeze(-1)) else: _sorted_tensors.append(tensor) sorted_tensors = _sorted_tensors @@ -193,6 +193,53 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: return torch.stack(sorted_tensors, dim=self.dim).float() +@registry.register("element-wise-sum") +class ElementWiseSum(AggModule): + """Element-wise sum of tensors. + + The input dictionary will be sorted by name before concatenation. + The sum is computed along the first dimension (default for Stack class). + + Example usage:: + >>> ewsum = ElementWiseSum() + >>> feature1 = torch.tensor([[1, 2], [3, 4]]) # Shape: [batch_size, feature_dim] + >>> feature2 = torch.tensor([[5, 6], [7, 8]]) # Shape: [batch_size, feature_dim] + >>> input_dict = {"feature1": feature1, "feature2": feature2} + >>> output = ewsum(input_dict) + >>> print(output) + tensor([[ 6, 8], + [10, 12]]) # Shape: [batch_size, feature_dim] + + """ + + def __init__(self): + super().__init__() + self.stack = Stack(dim=0) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Performs an element-wise sum of input tensors. + + Parameters + ---------- + inputs : Dict[str, torch.Tensor] + A dictionary where keys are the names of the tensors + and values are the tensors to be summed. + + Returns + ------- + torch.Tensor + A tensor that is the result of performing an element-wise sum + of the input tensors. + + Raises + ------ + RuntimeError + If the input tensor shapes don't match for stacking. 
+ """ + return self.stack(inputs).sum(dim=0) + + class MaybeAgg(BlockContainer): """ This class is designed to conditionally apply an aggregation operation diff --git a/merlin/models/torch/transforms/bias.py b/merlin/models/torch/transforms/bias.py new file mode 100644 index 0000000000..67ba7eb6f2 --- /dev/null +++ b/merlin/models/torch/transforms/bias.py @@ -0,0 +1,52 @@ +import torch +from torch import nn + + +class LogitsTemperatureScaler(nn.Module): + """ + A PyTorch Module for scaling logits with a given temperature value. + + This module is useful for implementing temperature scaling in neural networks, + a technique often used to soften or sharpen the output distribution of a classifier. + A temperature value closer to 0 makes the output probabilities more extreme + (either closer to 0 or 1), while a value closer to 1 makes the distribution + closer to uniform. + + Parameters + ---------- + temperature : float + The temperature value used for scaling. Must be a positive float in the range (0.0, 1.0]. + + Raises + ------ + ValueError + If the temperature value is not a float or is out of the range (0.0, 1.0]. + """ + + def __init__(self, temperature: float): + super().__init__() + + if not isinstance(temperature, float): + raise ValueError(f"Invalid temperature type: {type(temperature)}") + if not 0.0 < temperature <= 1.0: + raise ValueError( + f"Invalid temperature value: {temperature} ", "Must be in the range (0.0, 1.0]" + ) + + self.temperature = temperature + + def forward(self, logits: torch.Tensor) -> torch.Tensor: + """ + Apply temperature scaling to the input logits. + + Parameters + ---------- + logits : torch.Tensor + The input logits to be scaled. + + Returns + ------- + torch.Tensor + The scaled logits. + """ + return logits / self.temperature diff --git a/merlin/models/torch/transforms/sequences.py b/merlin/models/torch/transforms/sequences.py new file mode 100644 index 0000000000..4f96fc1897 --- /dev/null +++ b/merlin/models/torch/transforms/sequences.py @@ -0,0 +1,526 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from merlin.models.torch.batch import Batch, Sequence +from merlin.models.torch.block import BatchBlock +from merlin.models.torch.router import RouterBlock +from merlin.models.torch.schema import LazySchemaModuleMixin, Selection, select +from merlin.schema import Schema, Tags + + +class TabularPadding(BatchBlock): + """A PyTorch module for padding tabular sequence data. + + Parameters + ---------- + schema : Schema + The schema of the tabular data, which defines the column names of input features. + selection : Selection + The selection of the tabular data, which defines the column names of the + sequence input features. + max_sequence_length : Optional[int], default=None + The maximum length of the sequences after padding. 
diff --git a/merlin/models/torch/transforms/sequences.py b/merlin/models/torch/transforms/sequences.py
new file mode 100644
index 0000000000..4f96fc1897
--- /dev/null
+++ b/merlin/models/torch/transforms/sequences.py
@@ -0,0 +1,525 @@
+#
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from merlin.models.torch.batch import Batch, Sequence
+from merlin.models.torch.block import BatchBlock
+from merlin.models.torch.router import RouterBlock
+from merlin.models.torch.schema import LazySchemaModuleMixin, Selection, select
+from merlin.schema import Schema, Tags
+
+
+class TabularPadding(BatchBlock):
+    """A PyTorch module for padding tabular sequence data.
+
+    Parameters
+    ----------
+    schema : Schema
+        The schema of the tabular data, which defines the column names of input features.
+    selection : Selection
+        The selection of the tabular data, which defines the column names of the
+        sequence input features.
+    max_sequence_length : Optional[int], default=None
+        The maximum length of the sequences after padding.
+        If None, sequences will be padded to the maximum length in the current batch.
+
+    Example usage::
+        features = {
+            'feature1': torch.tensor([[4, 3], [5, 2]]),
+            'feature2': torch.tensor([[3, 8], [7, 9]])
+        }
+        schema = Schema(["feature1", "feature2"])
+        _max_sequence_length = 10
+        padding_op = TabularPadding(
+            schema=schema, max_sequence_length=_max_sequence_length
+        )
+        padded_batch = padding_op(Batch(features))
+
+    Notes:
+        - If the schema contains continuous list features,
+          ensure that they are normalized within the range of [0, 1].
+          This is necessary because we will be padding them
+          to a max_sequence_length using the minimum value of 0.0.
+        - The current class only supports right padding.
+    """
+
+    def __init__(
+        self,
+        schema: Optional[Schema] = None,
+        selection: Optional[Selection] = Tags.SEQUENCE,
+        max_sequence_length: Optional[int] = None,
+        name: Optional[str] = None,
+    ):
+        _padding = TabularPaddingModule(
+            schema=schema, selection=selection, max_sequence_length=max_sequence_length
+        )
+
+        if selection is None:
+            _to_add = _padding
+        else:
+            _to_add = RouterBlock(schema)
+            _to_add.add_route(selection, _padding)
+
+        super().__init__(_to_add, name=name)
+
+
+class TabularPaddingModule(nn.Module):
+    """A PyTorch module for padding tabular sequence data."""
+
+    def __init__(
+        self,
+        schema: Optional[Schema] = None,
+        selection: Selection = Tags.SEQUENCE,
+        max_sequence_length: Optional[int] = None,
+    ):
+        super().__init__()
+        self.selection = selection
+        if schema:
+            self.initialize_from_schema(schema)
+            self._initialized_from_schema = True
+        self.max_sequence_length = max_sequence_length
+        self.padding_idx = 0
+
+    def initialize_from_schema(self, schema: Schema):
+        self.schema = schema
+        if self.selection:
+            self.features: List[str] = self.schema.column_names
+            self.seq_features = select(self.schema, self.selection).column_names
+        else:
+            self.features = self.schema.column_names
+            self.seq_features = self.schema.column_names
+
+    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Batch) -> Batch:
+        if not torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
+            raise RuntimeError(
+                f"TabularPaddingModule expects a dictionary of tensors as input, got: {inputs}"
+            )
+
+        _max_sequence_length = self.max_sequence_length
+        if not torch.jit.isinstance(_max_sequence_length, int):
+            # Infer the maximum length from the current batch
+            batch_max_sequence_length = 0
+            for key, val in inputs.items():
+                if key.endswith("__offsets"):
+                    offsets = val
+                    max_row_length = int(torch.max(offsets[1:] - offsets[:-1]))
+                    batch_max_sequence_length = max(max_row_length, batch_max_sequence_length)
+            _max_sequence_length = batch_max_sequence_length
+
+        # Store the non-padded lengths of list features
+        seq_inputs_lengths = self._get_sequence_lengths(inputs)
+        seq_shapes: List[torch.Tensor] = list(seq_inputs_lengths.values())
+        if not torch.all(torch.stack([torch.all(x == seq_shapes[0]) for x in seq_shapes])):
+            raise ValueError(
+                "The sequential inputs must have the same length for each row in the batch, "
+                f"but they are different: {seq_shapes}"
+            )
+        # Pad the features of the batch
+        batch_padded = {}
+        for key, value in batch.features.items():
+            if key.endswith("__offsets"):
+                col_name = key[: -len("__offsets")]
+                if col_name in self.features:
+                    padded_values = self._pad_ragged_tensor(
+                        batch.features[f"{col_name}__values"], value, _max_sequence_length
+                    )
+                    batch_padded[col_name] = padded_values
+            elif key.endswith("__values"):
+                continue
+            else:
+                col_name = key
+                if col_name in self.features and seq_inputs_lengths.get(col_name) is not None:
+                    # pad dense list features
+                    batch_padded[col_name] = self._pad_dense_tensor(value, _max_sequence_length)
+
+        # Pad targets of the batch
+        targets_padded = None
+        if batch.targets is not None:
+            targets_padded = {}
+            for key, value in batch.targets.items():
+                if key.endswith("__offsets"):
+                    col_name = key[: -len("__offsets")]
+                    padded_values = self._pad_ragged_tensor(
+                        batch.targets[f"{col_name}__values"], value, _max_sequence_length
+                    )
+                    targets_padded[col_name] = padded_values
+                elif key.endswith("__values"):
+                    continue
+                else:
+                    targets_padded[key] = value
+
+        return Batch(
+            features=batch_padded, targets=targets_padded, sequences=Sequence(seq_inputs_lengths)
+        )
+
+    def _get_sequence_lengths(self, sequences: Dict[str, torch.Tensor]):
+        """Compute the effective length of each sequence in a dictionary of sequences."""
+        seq_inputs_lengths = {}
+        for key, val in sequences.items():
+            if key.endswith("__offsets"):
+                seq_inputs_lengths[key[: -len("__offsets")]] = val[1:] - val[:-1]
+            elif key in self.seq_features:
+                seq_inputs_lengths[key] = (val != self.padding_idx).sum(-1)
+        return seq_inputs_lengths
+
+    def _squeeze(self, tensor: torch.Tensor):
+        """Squeeze a tensor of shape (N,1) to shape (N)."""
+        if len(tensor.shape) == 2:
+            return tensor.squeeze(1)
+        return tensor
+
+    def _get_indices(self, offsets: torch.Tensor, diff_offsets: torch.Tensor):
+        """Compute indices for a sparse tensor from offsets and their differences."""
+        row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
+        row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
+        row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
+        col_ids = (
+            torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated
+        )
+        indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], dim=1)
+        return indices
+
+    def _pad_ragged_tensor(self, values: torch.Tensor, offsets: torch.Tensor, padding_length: int):
+        """Pad a ragged feature represented by "values" and "offsets" to a dense tensor
+        of length `padding_length`.
+        """
+        values = self._squeeze(values)
+        offsets = self._squeeze(offsets)
+        num_rows = len(offsets) - 1
+        diff_offsets = offsets[1:] - offsets[:-1]
+        max_length = int(diff_offsets.max())
+        indices = self._get_indices(offsets, diff_offsets)
+        sparse_tensor = torch.sparse_coo_tensor(
+            indices.T, values, torch.Size([num_rows, max_length]), device=values.device
+        )
+
+        return self._pad_dense_tensor(sparse_tensor.to_dense(), padding_length)
+
+    def _pad_dense_tensor(self, tensor: torch.Tensor, length: int) -> torch.Tensor:
+        """Pad a dense tensor along its second dimension to a specified length."""
+        if len(tensor.shape) == 2:
+            pad_diff = length - tensor.shape[1]
+            return F.pad(input=tensor, pad=(0, pad_diff, 0, 0))
+        return tensor
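
The `__values`/`__offsets` convention used throughout this file stores a ragged list feature as one flat tensor of values plus a tensor of row boundaries. A standalone sketch of the dense result that `_pad_ragged_tensor` produces (plain torch with illustrative values; the real implementation goes through a sparse COO tensor instead of a Python loop):

```python
import torch

# Two rows: [10, 20, 30] and [40], stored flat with row offsets
values = torch.tensor([10, 20, 30, 40])
offsets = torch.tensor([0, 3, 4])

lengths = offsets[1:] - offsets[:-1]  # tensor([3, 1]) -> per-row lengths
padded = torch.zeros(len(lengths), 5, dtype=values.dtype)  # right-pad to length 5
for row in range(len(lengths)):
    padded[row, : lengths[row]] = values[offsets[row] : offsets[row + 1]]

print(padded)
# tensor([[10, 20, 30,  0,  0],
#         [40,  0,  0,  0,  0]])
```
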
+
+
+class BroadcastToSequence(nn.Module, LazySchemaModuleMixin):
+    """
+    A PyTorch module to broadcast features to match the sequence length.
+
+    BroadcastToSequence is a PyTorch module designed to facilitate broadcasting
+    of specific features within a given data schema to match a given sequence length.
+    This can be particularly useful in sequence-based neural networks, where different
+    types of inputs need to be processed in sync within the network, and all inputs need
+    to be of the same length.
+
+    For example, in a sequence-to-sequence learning problem, one might have a feature
+    representing a constant property for each sequence (like an ID or a group) that
+    should be available at each time step. In this case, BroadcastToSequence can be
+    used to 'broadcast' this feature along the time dimension,
+    creating a copy for each time step.
+
+    Parameters
+    ----------
+    to_broadcast : Selection
+        The features that need to be broadcast.
+    sequence : Selection
+        The sequence features.
+
+    """
+
+    def __init__(
+        self,
+        to_broadcast: Selection,
+        sequence: Selection,
+        schema: Optional[Schema] = None,
+    ):
+        super().__init__()
+        self.to_broadcast = to_broadcast
+        self.sequence = sequence
+        self.to_broadcast_features: List[str] = []
+        self.sequence_features: List[str] = []
+        if schema:
+            self.initialize_from_schema(schema)
+
+    def initialize_from_schema(self, schema: Schema):
+        """
+        Initialize the module from a schema.
+
+        Parameters
+        ----------
+        schema : Schema
+            The input-schema of this module
+        """
+        super().initialize_from_schema(schema)
+        self.schema = schema
+        self.to_broadcast_features = select(schema, self.to_broadcast).column_names
+        self.sequence_features = select(schema, self.sequence).column_names
+
+    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Forward propagation method.
+
+        Parameters
+        ----------
+        inputs : Dict[str, torch.Tensor]
+            The inputs dictionary containing the tensors to be broadcast.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            The dictionary containing the broadcast tensors.
+
+        Raises
+        ------
+        RuntimeError
+            If a tensor has an unsupported number of dimensions.
+        """
+
+        outputs = {}
+        seq_length = self.get_seq_length(inputs)
+
+        # Iterate over the to_broadcast_features and broadcast each tensor to the sequence length
+        for key, val in inputs.items():
+            if key in self.to_broadcast_features:
+                # Check the dimension of the original tensor
+                if len(val.shape) == 1:  # for 1D tensor (batch dimension only)
+                    broadcasted_tensor = val.unsqueeze(1).repeat(1, seq_length)
+                elif len(val.shape) == 2:  # for 2D tensor (batch dimension + feature dimension)
+                    broadcasted_tensor = val.unsqueeze(1).repeat(1, seq_length, 1)
+                else:
+                    raise RuntimeError(f"Unsupported number of dimensions: {len(val.shape)}")
+
+                # Update the inputs dictionary with the broadcasted tensor
+                outputs[key] = broadcasted_tensor
+            else:
+                outputs[key] = val
+
+        return outputs
+
+    def get_seq_length(self, inputs: Dict[str, torch.Tensor]) -> int:
+        """
+        Get the sequence length from inputs.
+
+        Parameters
+        ----------
+        inputs : Dict[str, torch.Tensor]
+            The inputs dictionary.
+
+        Returns
+        -------
+        int
+            The sequence length.
+        """
+
+        if len(self.sequence_features) == 0:
+            raise RuntimeError("No sequence features found in the inputs.")
+
+        first_feat: str = self.sequence_features[0]
+
+        if first_feat + "__offsets" in inputs:
+            return inputs[first_feat + "__offsets"][-1].item()
+
+        return inputs[first_feat].shape[1]
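
A minimal sketch of the broadcast behaviour; the feature names and tags below are illustrative, not part of the patch:

```python
import torch

from merlin.models.torch.transforms.sequences import BroadcastToSequence
from merlin.schema import ColumnSchema, Schema, Tags

# "user_age" is a per-row context feature; "item_ids" is a sequence feature
schema = Schema(
    [
        ColumnSchema("user_age", tags=[Tags.USER]),
        ColumnSchema("item_ids", tags=[Tags.SEQUENCE]),
    ]
)
broadcast = BroadcastToSequence(Tags.USER, Tags.SEQUENCE, schema=schema)

inputs = {
    "user_age": torch.tensor([0.3, 0.7]),  # shape: [batch]
    "item_ids": torch.rand(2, 4),          # shape: [batch, seq_len]
}
outputs = broadcast(inputs)
print(outputs["user_age"].shape)  # torch.Size([2, 4]): repeated at each time step
```
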
+
+
+class TabularPredictNext(BatchBlock):
+    """A BatchBlock for preparing sequential inputs and targets
+    for next-item prediction. The target is extracted from the shifted
+    sequence of the target feature and the sequential input features
+    are truncated at the last position.
+
+    Parameters
+    ----------
+    target : Optional[Selection], default=Tags.ID
+        The sequential input column(s) that will be used to extract the target.
+        Targets can be one or multiple input features with the same sequence length.
+    schema : Optional[Schema]
+        The schema with the sequential columns to be truncated.
+    apply_padding : Optional[bool], default=True
+        Whether to pad sequential inputs before extracting the target(s).
+    max_sequence_length : Optional[int], default=None
+        The maximum length of the sequences after padding.
+        If None, sequences will be padded to the maximum length in the current batch.
+
+    Example usage::
+        features = {
+            'feature1': torch.tensor([[4, 3], [5, 2]]),
+            'feature2': torch.tensor([[3, 8], [7, 9]])
+        }
+        schema = Schema(["feature1", "feature2"])
+        next_item_op = TabularPredictNext(
+            schema=schema, target='feature1'
+        )
+        transformed_batch = next_item_op(Batch(features))
+    """
+
+    def __init__(
+        self,
+        target: Selection = Tags.ID,
+        schema: Optional[Schema] = None,
+        apply_padding: bool = True,
+        max_sequence_length: Optional[int] = None,
+        name: Optional[str] = None,
+    ):
+        super().__init__(
+            TabularPredictNextModule(
+                schema=schema,
+                target=target,
+                apply_padding=apply_padding,
+                max_sequence_length=max_sequence_length,
+            ),
+            name=name,
+        )
+
+
+class TabularSequenceTransform(nn.Module):
+    """Base PyTorch module for preparing targets from a batch of sequential inputs.
+
+    Parameters
+    ----------
+    target : Optional[Selection], default=Tags.ID
+        The sequential input column that will be used to extract the target.
+        In case of multiple targets, either a list of target feature names
+        or a shared Tag indicating the targets should be provided.
+    schema : Optional[Schema]
+        The schema with the sequential columns to be truncated.
+    apply_padding : Optional[bool], default=True
+        Whether to pad sequential inputs before extracting the target(s).
+    max_sequence_length : Optional[int], default=None
+        The maximum length of the sequences after padding.
+        If None, sequences will be padded to the maximum length in the current batch.
+    """
+
+    def __init__(
+        self,
+        target: Optional[Selection] = Tags.ID,
+        schema: Optional[Schema] = None,
+        apply_padding: bool = True,
+        max_sequence_length: Optional[int] = None,
+    ):
+        super().__init__()
+        self.target = target
+        if schema:
+            self.initialize_from_schema(schema)
+            self._initialized_from_schema = True
+        self.padding_idx = 0
+        self.apply_padding = apply_padding
+        if self.apply_padding:
+            self.padding_operator = TabularPadding(
+                schema=self.schema, max_sequence_length=max_sequence_length
+            )
+
+    def initialize_from_schema(self, schema: Schema):
+        self.schema = schema
+        self.features: List[str] = self.schema.column_names
+        target = select(self.schema, self.target)
+        if not target:
+            raise ValueError(
+                f"The target '{self.target}' was not found in the "
+                f"provided sequential schema: {self.schema}"
+            )
+        self.target_name = self._get_target(target)
+
+    def _get_target(self, target):
+        return target.column_names
+
+    def forward(
+        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Batch, **kwargs
+    ) -> Batch:
+        raise NotImplementedError()
+
+    def _check_seq_inputs_targets(self, batch: Batch):
+        self._check_input_sequence_lengths(batch)
+        self._check_target_shape(batch)
+
+    def _check_target_shape(self, batch: Batch):
+        for name in self.target_name:
+            if name not in batch.features:
+                raise ValueError(f"Inputs features do not contain target column ({name})")
+
+            target = batch.features[name]
+            if target.ndim < 2:
+                raise ValueError(
+                    f"Sequential target column ({name}) "
+                    f"must be a 2D tensor, but it has {target.ndim} dimension(s)"
+                )
+            lengths = batch.sequences.length(name)
+            if any(lengths <= 1):
+                raise ValueError(
+                    f"2nd dim of target column ({name}) "
+                    "must be greater than 1 for sequential input to be shifted as target"
+                )
+
+    def _check_input_sequence_lengths(self, batch: Batch):
+        if not batch.sequences.lengths:
+            raise ValueError(
+                "The input `batch` should include information about input sequences lengths"
+            )
+        sequence_lengths = torch.stack([batch.sequences.length(name) for name in self.features])
+        assert torch.all(sequence_lengths.eq(sequence_lengths[0])), (
+            "All tabular sequence features need to have the same sequence length, "
+            f"found {sequence_lengths}"
+        )
+
+
+class TabularPredictNextModule(TabularSequenceTransform):
+    """A PyTorch module for preparing tabular sequence data for next-item prediction."""
+
+    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Batch) -> Batch:
+        if self.apply_padding:
+            batch = self.padding_operator(batch)
+        self._check_seq_inputs_targets(batch)
+
+        # Shifts the target column to be the next item of corresponding input column
+        new_targets: Dict[str, torch.Tensor] = dict()
+        for name in self.target_name:
+            new_target = batch.features[name]
+            new_target = new_target[:, 1:]
+            new_targets[name] = new_target
+
+        # Removes the last item of the sequence, as it belongs to the target
+        new_inputs = dict()
+        for k, v in batch.features.items():
+            if k in self.features:
+                new_inputs[k] = v[:, :-1]
+
+        # Generates information about the new lengths and causal masks
+        new_lengths, causal_masks = {}, {}
+        for name in self.features:
+            new_lengths[name] = batch.sequences.lengths[name] - 1
+        _max_length = list(new_targets.values())[0].shape[
+            -1
+        ]  # all new targets have the same output sequence length
+        causal_mask = self._generate_causal_mask(list(new_lengths.values())[0], _max_length)
+        for name in self.features:
+            causal_masks[name] = causal_mask
+
+        return Batch(
+            features=new_inputs,
+            targets=new_targets,
+            sequences=Sequence(new_lengths, masks=causal_masks),
+        )
+
+    def _generate_causal_mask(self, seq_lengths: torch.Tensor, max_len: int):
+        """
+        Generate a 2D mask from a tensor of sequence lengths.
+        """
+        return torch.arange(max_len)[None, :] < seq_lengths[:, None]
diff --git a/merlin/models/torch/transforms/tuple.py b/merlin/models/torch/transforms/tuple.py
new file mode 100644
index 0000000000..9b995fdffb
--- /dev/null
+++ b/merlin/models/torch/transforms/tuple.py
@@ -0,0 +1,228 @@
+import sys
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from merlin.schema import Schema
+
+
+def ToTuple(schema: Schema) -> "ToTupleModule":
+    """
+    Creates a ToTupleModule for a given schema.
+
+    This function is especially useful for serving models with Triton,
+    as Triton doesn't allow models that output a dictionary. Instead,
+    by using this function, models can be modified to output tuples.
+
+    Parameters
+    ----------
+    schema : Schema
+        Input schema for which a ToTupleModule is to be created.
+
+    Returns
+    -------
+    ToTupleModule
+        A ToTupleModule corresponding to the length of the given schema.
+        The output can vary from ToTuple1 to ToTuple10.
+
+    Raises
+    ------
+    ValueError
+        If the length of the schema is more than 10,
+        a ValueError is raised with an appropriate error message.
+
+    Example usage::
+        >>> import torch
+        >>> schema = Schema(["a", "b", "c"])
+        >>> to_tuple = ToTuple(schema)
+        >>> tensor_dict = {'a': torch.tensor([1]), 'b': torch.tensor([2.]), 'c': torch.tensor([2.])}
+        >>> output = to_tuple(tensor_dict)
+        >>> print(output)
+        (tensor([1]), tensor([2.]), tensor([2.]))
+    """
+    schema_length = len(schema)
+
+    if schema_length <= 10:
+        ToTupleClass = getattr(sys.modules[__name__], f"ToTuple{schema_length}")
+        return ToTupleClass(input_schema=schema)
+    else:
+        raise ValueError(f"Cannot convert schema of length {schema_length} to a tuple")
+
+
+class ToTupleModule(nn.Module):
+    def __init__(self, input_schema: Optional[Schema] = None):
+        super().__init__()
+        if input_schema is not None:
+            self.initialize_from_schema(input_schema)
+            self._initialized_from_schema = True
+
+    def initialize_from_schema(self, input_schema: Schema):
+        self._input_schema = input_schema
+        self._column_names = input_schema.column_names
+
+    def value_list(self, inputs: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
+        outputs: List[torch.Tensor] = []
+
+        if not hasattr(self, "_column_names"):
+            raise RuntimeError("initialize_from_schema() must be called before value_list()")
+
+        for col in self._column_names:
+            outputs.append(inputs[col])
+
+        return outputs
+
+
+class ToTuple1(ToTupleModule):
+    """Converts a dictionary of tensors of length=1 to a tuple of tensors."""
+
+    def forward(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor]:
+        _list = list(inputs.values())
+        return (_list[0],)
+
+
+class ToTuple2(ToTupleModule):
+    """Converts a dictionary of tensors of length=2 to a tuple of tensors."""
+
+    def forward(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        _list = self.value_list(inputs)
+        return (_list[0], _list[1])
+
+
+class ToTuple3(ToTupleModule):
+    """Converts a dictionary of tensors of length=3 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        _list = self.value_list(inputs)
+        return (_list[0], _list[1], _list[2])
+
+
+class ToTuple4(ToTupleModule):
+    """Converts a dictionary of tensors of length=4 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        _list = self.value_list(inputs)
+        return (_list[0], _list[1], _list[2], _list[3])
+
+
+class ToTuple5(ToTupleModule):
+    """Converts a dictionary of tensors of length=5 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        _list = self.value_list(inputs)
+        return (_list[0], _list[1], _list[2], _list[3], _list[4])
+
+
+class ToTuple6(ToTupleModule):
+    """Converts a dictionary of tensors of length=6 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        _list = self.value_list(inputs)
+        return (_list[0], _list[1], _list[2], _list[3], _list[4], _list[5])
+
+
+class ToTuple7(ToTupleModule):
+    """Converts a dictionary of tensors of length=7 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        _list = self.value_list(inputs)
+        return (_list[0], _list[1], _list[2], _list[3], _list[4], _list[5], _list[6])
+
+
+class ToTuple8(ToTupleModule):
+    """Converts a dictionary of tensors of length=8 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        _list = self.value_list(inputs)
+        return (_list[0], _list[1], _list[2], _list[3], _list[4], _list[5], _list[6], _list[7])
+
+
+class ToTuple9(ToTupleModule):
+    """Converts a dictionary of tensors of length=9 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        _list = self.value_list(inputs)
+        return (
+            _list[0],
+            _list[1],
+            _list[2],
+            _list[3],
+            _list[4],
+            _list[5],
+            _list[6],
+            _list[7],
+            _list[8],
+        )
+
+
+class ToTuple10(ToTupleModule):
+    """Converts a dictionary of tensors of length=10 to a tuple of tensors."""
+
+    def forward(
+        self, inputs: Dict[str, torch.Tensor]
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        _list = self.value_list(inputs)
+        return (
+            _list[0],
+            _list[1],
+            _list[2],
+            _list[3],
+            _list[4],
+            _list[5],
+            _list[6],
+            _list[7],
+            _list[8],
+            _list[9],
+        )
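
Because the tuple positions come from the schema's column order (via `value_list`) rather than dictionary insertion order, serving code gets a stable output layout. A small sketch (column names are illustrative):

```python
import torch

from merlin.models.torch.transforms.tuple import ToTuple
from merlin.schema import Schema

to_tuple = ToTuple(Schema(["a", "b"]))

# Dict insertion order differs from schema order; the output still follows the schema
out = to_tuple({"b": torch.tensor([2.0]), "a": torch.tensor([1.0])})
print(out)  # (tensor([1.]), tensor([2.]))
```
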
diff --git a/merlin/models/torch/utils/module_utils.py b/merlin/models/torch/utils/module_utils.py
index 4dcc965c51..1117b9f2f2 100644
--- a/merlin/models/torch/utils/module_utils.py
+++ b/merlin/models/torch/utils/module_utils.py
@@ -140,7 +140,7 @@ def module_test(module: nn.Module, input_data, method="script", schema_trace=Tru
         module.to(device=input_data.device())
         kwargs["batch"] = input_data
         input_data = input_data.features
-    elif "batch" in kwargs:
+    elif "batch" in kwargs and isinstance(kwargs["batch"], Batch):
         module.to(device=kwargs["batch"].device())
 
     # Check if the module can be called with the provided inputs
diff --git a/merlin/models/torch/utils/traversal_utils.py b/merlin/models/torch/utils/traversal_utils.py
index d458980cce..51c9ed3cdb 100644
--- a/merlin/models/torch/utils/traversal_utils.py +++ b/merlin/models/torch/utils/traversal_utils.py @@ -127,6 +127,9 @@ def leaf(module) -> nn.Module: return child.leaf() return leaf(child) else: + if isinstance(module, containers) and not hasattr(module, "items"): + return module[-1] + # If more than one child, throw an exception. raise ValueError( f"Module {module} has multiple children, cannot determine the deepest child." diff --git a/pyproject.toml b/pyproject.toml index 0d5b2942be..63492e2264 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,5 +62,7 @@ markers = [ "integration", "unit", "changed", - "unchanged" + "unchanged", + "singlegpu", + "multigpu" ] diff --git a/pytest.ini b/pytest.ini index ee9e7ef9be..467baa6f57 100644 --- a/pytest.ini +++ b/pytest.ini @@ -15,3 +15,5 @@ markers = horovod: mark as requiring horovod changed: mark as requiring changed files always: mark as always running + multigpu: Tests only run in multiple-GPU environments + singlegpu: Optional marker to run tests in single-GPU environments. Usually used when running in both single- and multi-GPU. diff --git a/requirements/docs.txt b/requirements/docs.txt index 2e72b146ad..6f335f911f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,12 +1,11 @@ attrs==21.4.0 Jinja2<3.1 linkify-it-py==1.0.3 -myst-nb==0.13.2 +myst-nb natsort==8.1.0 recommonmark==0.7.1 -sphinx_rtd_theme==1.0.0 +sphinx_book_theme~=1.0.1 sphinx-external-toc==0.2.4 sphinx-multiversion@git+https://github.com/mikemckiernan/sphinx-multiversion.git -Sphinx==3.5.4 sphinxcontrib-copydirs@git+https://github.com/mikemckiernan/sphinxcontrib-copydirs.git ipython==8.2.0 diff --git a/requirements/tensorflow.txt b/requirements/tensorflow.txt index 95fa5dd10e..203cf5888c 100644 --- a/requirements/tensorflow.txt +++ b/requirements/tensorflow.txt @@ -1 +1 @@ -tensorflow>=2.9 +tensorflow>=2.9,<2.13 diff --git a/tests/benchmark/test_asvdb_transformers_next_item_prediction.py b/tests/benchmark/test_asvdb_transformers_next_item_prediction.py deleted file mode 100644 index 49aa27416e..0000000000 --- a/tests/benchmark/test_asvdb_transformers_next_item_prediction.py +++ /dev/null @@ -1,95 +0,0 @@ -from asvdb import ASVDb, BenchmarkResult, utils -from testbook import testbook - -from tests.conftest import REPO_ROOT, get_benchmark_info - - -@testbook( - REPO_ROOT / "examples/usecases/transformers-next-item-prediction.ipynb", - timeout=720, - execute=False, -) -def test_func(tb, tmpdir): - tb.inject( - f""" - import os - os.environ["INPUT_DATA_DIR"] = "/raid/data/booking" - os.environ["OUTPUT_DATA_DIR"] = "{tmpdir}" - os.environ["NUM_EPOCHS"] = '1' - """ - ) - tb.cells.pop(6) - tb.cells[ - 15 - ].source = """ - def process_data(): - wf = Workflow(filtered_sessions) - - wf.fit_transform(train_set_dataset).to_parquet( - os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet') - ) - wf.transform(validation_set_dataset).to_parquet( - os.path.join(OUTPUT_DATA_DIR, 'validation_processed.parquet') - ) - - wf.save(os.path.join(OUTPUT_DATA_DIR, 'workflow')) - - data_processing_runtime = timeit.timeit(process_data, number=1) - """ - tb.cells[ - 29 - ].source = """ - model.compile(run_eagerly=False, optimizer='adam', loss="categorical_crossentropy") - - def train_model(): - model.fit( - Dataset(os.path.join(OUTPUT_DATA_DIR, 'train_processed.parquet')), - batch_size=64, - epochs=NUM_EPOCHS, - pre=mm.SequenceMaskRandom( - schema=seq_schema, - target=target, - masking_prob=0.3, - transformer=transformer_block - ) - ) - - training_runtime = 
timeit.timeit(train_model, number=1) - """ - tb.execute_cell(list(range(0, 35))) - data_processing_runtime = tb.ref("data_processing_runtime") - training_runtime = tb.ref("training_runtime") - ndcg_at_10 = tb.ref("metrics")["ndcg_at_10"] - - bResult1 = BenchmarkResult( - funcName="", - argNameValuePairs=[ - ("notebook_name", "usecases/transformers-next-item-prediction"), - ("measurement", "data_processing_runtime"), - ], - result=data_processing_runtime, - ) - bResult2 = BenchmarkResult( - funcName="", - argNameValuePairs=[ - ("notebook_name", "usecases/transformers-next-item-prediction"), - ("measurement", "training_runtime"), - ], - result=training_runtime, - ) - bResult3 = BenchmarkResult( - funcName="", - argNameValuePairs=[ - ("notebook_name", "usecases/transformers-next-item-prediction"), - ("measurement", "ndcg_at_10"), - ], - result=ndcg_at_10, - ) - - bInfo = get_benchmark_info() - (repo, branch) = utils.getRepoInfo() - - db = ASVDb(dbDir="s3://nvtab-bench-asvdb/models_metric_tracking", repo=repo, branches=[branch]) - db.addResult(bInfo, bResult1) - db.addResult(bInfo, bResult2) - db.addResult(bInfo, bResult3) diff --git a/tests/integration/torch/__init__.py b/tests/integration/torch/__init__.py new file mode 100644 index 0000000000..598c04b960 --- /dev/null +++ b/tests/integration/torch/__init__.py @@ -0,0 +1,19 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +pytest.importorskip("torch") +pytest.importorskip("pytorch_lightning") diff --git a/tests/integration/torch/test_multi_gpu.py b/tests/integration/torch/test_multi_gpu.py new file mode 100644 index 0000000000..de510a7fc4 --- /dev/null +++ b/tests/integration/torch/test_multi_gpu.py @@ -0,0 +1,31 @@ +import pytest +import pytorch_lightning as pl + +import merlin.models.torch as mm + + +# TODO: This test is not complete because Lightning launches separate processes +# under the hood with the correct environment variables like `LOCAL_RANK`, but +# the pytest stays in the main process and tests only the LOCAL_RANK=0 case. +# Follow-up with proper test that ensures dataloader is working properly with +# e.g. global_rank > 0. +class TestMultiGPU: + @pytest.mark.multigpu + def test_multi_gpu(self, music_streaming_data): + schema = music_streaming_data.schema + data = music_streaming_data + model = mm.Model( + mm.TabularInputBlock(schema, init="defaults"), + mm.MLPBlock([5]), + mm.BinaryOutput(schema["click"]), + ) + + trainer = pl.Trainer(max_epochs=3, devices=2) + multi_loader = mm.MultiLoader(data, batch_size=2) + trainer.fit(model, multi_loader) + + # 100 rows total / 2 devices -> 50 rows per device + # 50 rows / 2 per batch -> 25 steps per device + assert trainer.num_training_batches == 25 + + assert trainer.global_rank == 0 # This should fail for node 1. 
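
`mm.MultiLoader` above maps its positional datasets to the train, validation, and test loaders (see `TestMultiLoader` further down). A hedged sketch of a two-dataset setup; the file paths and the binary `click` target column are illustrative assumptions, not part of the patch:

```python
import pytorch_lightning as pl

import merlin.models.torch as mm
from merlin.io import Dataset

# Illustrative datasets; any merlin.io.Dataset with a suitable schema works
train_ds = Dataset("train.parquet")
valid_ds = Dataset("valid.parquet")

schema = train_ds.schema
model = mm.Model(
    mm.TabularInputBlock(schema, init="defaults"),
    mm.MLPBlock([5]),
    mm.BinaryOutput(schema["click"]),  # assumes a binary "click" target column
)

trainer = pl.Trainer(max_epochs=1, devices=2)

# The first dataset becomes the train loader, the second the validation loader
trainer.fit(model, mm.MultiLoader(train_ds, valid_ds, batch_size=16))
```
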
diff --git a/tests/unit/tf/examples/test_usecase_ecommerce_session_based.py b/tests/unit/tf/examples/test_08_session_based_next_item_prediction.py similarity index 95% rename from tests/unit/tf/examples/test_usecase_ecommerce_session_based.py rename to tests/unit/tf/examples/test_08_session_based_next_item_prediction.py index 20272fa6b8..105dfda63c 100644 --- a/tests/unit/tf/examples/test_usecase_ecommerce_session_based.py +++ b/tests/unit/tf/examples/test_08_session_based_next_item_prediction.py @@ -6,7 +6,7 @@ pytest.importorskip("transformers") -p = "examples/usecases/ecommerce-session-based-next-item-prediction-for-fashion.ipynb" +p = "examples/08-Train-a-model-for-session-based-next-item-prediction.ipynb" @testbook( diff --git a/tests/unit/tf/examples/test_usecase_retrieval_with_hpo.py b/tests/unit/tf/examples/test_usecase_retrieval_with_hpo.py index 4fabbd8d14..f1727ba27d 100644 --- a/tests/unit/tf/examples/test_usecase_retrieval_with_hpo.py +++ b/tests/unit/tf/examples/test_usecase_retrieval_with_hpo.py @@ -3,11 +3,14 @@ from tests.conftest import REPO_ROOT +pytest.importorskip("plotly") optuna = pytest.importorskip("optuna") @testbook( - REPO_ROOT / "examples/usecases/retrieval-with-hyperparameter-optimization.ipynb", execute=False + REPO_ROOT / "examples/usecases/retrieval-with-hyperparameter-optimization.ipynb", + execute=False, + timeout=120, ) @pytest.mark.notebook def test_usecase_pretrained_embeddings(tb): diff --git a/tests/unit/tf/examples/test_usecase_transformers_next_item_prediction.py b/tests/unit/tf/examples/test_usecase_transformers_next_item_prediction.py deleted file mode 100644 index a0dc9d9fdd..0000000000 --- a/tests/unit/tf/examples/test_usecase_transformers_next_item_prediction.py +++ /dev/null @@ -1,57 +0,0 @@ -import shutil - -import pytest -from testbook import testbook - -from tests.conftest import REPO_ROOT - -pytest.importorskip("transformers") -utils = pytest.importorskip("merlin.systems.triton.utils") - -TRITON_SERVER_PATH = shutil.which("tritonserver") - - -@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found") -@testbook( - REPO_ROOT / "examples/usecases/transformers-next-item-prediction.ipynb", - timeout=720, - execute=False, -) -@pytest.mark.notebook -def test_next_item_prediction(tb, tmpdir): - tb.inject( - f""" - import os, random - os.environ["INPUT_DATA_DIR"] = "{tmpdir}" - os.environ["OUTPUT_DATA_DIR"] = "{tmpdir}" - from datetime import datetime, timedelta - from merlin.datasets.synthetic import generate_data - ds = generate_data('booking.com-raw', 10000) - df = ds.compute() - def generate_date(): - date = datetime.today() - if random.randint(0, 1): - date -= timedelta(days=7) - return date - df['checkin'] = [generate_date() for _ in range(df.shape[0])] - df['checkout'] = [generate_date() for _ in range(df.shape[0])] - df.to_csv('{tmpdir}/train_set.csv') - """ - ) - tb.cells.pop(6) - tb.cells[29].source = tb.cells[29].source.replace("epochs=5", "epochs=1") - tb.execute_cell(list(range(0, 38))) - - with utils.run_triton_server(f"{tmpdir}/ensemble", grpc_port=8001): - tb.execute_cell(list(range(38, len(tb.cells)))) - - tb.inject( - """ - logits_count = predictions.shape[1] - """ - ) - tb.execute_cell(len(tb.cells) - 1) - - cardinality = tb.ref("cardinality") - logits_count = tb.ref("logits_count") - assert logits_count == cardinality diff --git a/tests/unit/tf/examples/test_usecase_transformers_next_item_prediction_with_pretrained_embeddings.py 
b/tests/unit/tf/examples/test_usecase_transformers_next_item_prediction_with_pretrained_embeddings.py deleted file mode 100644 index cc7c0b4a70..0000000000 --- a/tests/unit/tf/examples/test_usecase_transformers_next_item_prediction_with_pretrained_embeddings.py +++ /dev/null @@ -1,38 +0,0 @@ -import shutil - -import pytest -from testbook import testbook - -from tests.conftest import REPO_ROOT - -pytest.importorskip("transformers") -utils = pytest.importorskip("merlin.systems.triton.utils") - -TRITON_SERVER_PATH = shutil.which("tritonserver") - - -@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found") -@testbook( - REPO_ROOT - / "examples/usecases/transformers-next-item-prediction-with-pretrained-embeddings.ipynb", - timeout=720, - execute=False, -) -@pytest.mark.notebook -def test_next_item_prediction(tb, tmpdir): - tb.inject( - f""" - import os, random - os.environ["OUTPUT_DATA_DIR"] = "{tmpdir}" - os.environ["NUM_EPOCHS"] = "1" - os.environ["NUM_EXAMPLES"] = "1_500" - os.environ["MINIMUM_SESSION_LENGTH"] = "2" - """ - ) - tb.execute_cell(list(range(0, 48))) - - with utils.run_triton_server(f"{tmpdir}/ensemble", grpc_port=8001): - tb.execute_cell(list(range(48, len(tb.cells)))) - - predicted_hashed_url_id = tb.ref("predicted_hashed_url_id").item() - assert predicted_hashed_url_id >= 0 and predicted_hashed_url_id <= 1002 diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py index 4fd51525bd..c7531c5f0a 100644 --- a/tests/unit/tf/models/test_retrieval.py +++ b/tests/unit/tf/models/test_retrieval.py @@ -1,11 +1,12 @@ from pathlib import Path -import nvtabular as nvt +import numpy as np import pytest import tensorflow as tf import merlin.models.tf as mm from merlin.core.dispatch import make_df +from merlin.dataloader.ops.embeddings import EmbeddingOperator from merlin.io import Dataset from merlin.models.tf.metrics.topk import ( AvgPrecisionAt, @@ -24,6 +25,8 @@ def test_two_tower_shared_embeddings(): + nvt = pytest.importorskip("nvtabular") + train = make_df( { "user_id": [1, 3, 3, 4, 3, 1, 2, 4, 6, 7, 8, 9] * 100, @@ -435,6 +438,66 @@ def test_two_tower_model_topk_evaluation(ecommerce_data: Dataset, run_eagerly): assert all([metric >= 0 for metric in metrics.values()]) +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_two_tower_model_topk_evaluation_with_pretrained_emb(music_streaming_data, run_eagerly): + music_streaming_data.schema = music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]) + + cardinality = music_streaming_data.schema["item_category"].int_domain.max + 1 + pretrained_embedding = np.random.rand(cardinality, 12) + + loader_transforms = [ + EmbeddingOperator( + pretrained_embedding, + lookup_key="item_category", + embedding_name="pretrained_category_embeddings", + ), + ] + loader = mm.Loader( + music_streaming_data, + schema=music_streaming_data.schema.select_by_tag([Tags.USER, Tags.ITEM]), + batch_size=10, + transforms=loader_transforms, + ) + schema = loader.output_schema + + pretrained_embeddings = mm.PretrainedEmbeddings( + schema.select_by_tag(Tags.EMBEDDING), + output_dims=16, + ) + + schema = loader.output_schema + + query_input = mm.InputBlockV2(schema.select_by_tag(Tags.USER)) + query = mm.Encoder(query_input, mm.MLPBlock([4], no_activation_last_layer=True)) + candidate_input = mm.InputBlockV2( + schema.select_by_tag(Tags.ITEM), pretrained_embeddings=pretrained_embeddings + ) + candidate = mm.Encoder(candidate_input, mm.MLPBlock([4], no_activation_last_layer=True)) + model = 
mm.TwoTowerModelV2( + query, + candidate, + negative_samplers=["in-batch"], + ) + model.compile(optimizer="adam", run_eagerly=run_eagerly) + _ = testing_utils.model_test(model, loader) + + # Top-K evaluation + candidate_features_data = unique_rows_by_features(music_streaming_data, Tags.ITEM, Tags.ITEM_ID) + loader_candidates = mm.Loader( + candidate_features_data, + batch_size=16, + transforms=loader_transforms, + ) + + topk_model = model.to_top_k_encoder(loader_candidates, k=20, batch_size=16) + topk_model.compile(run_eagerly=run_eagerly) + + loader = mm.Loader(music_streaming_data, batch_size=32).map(mm.ToTarget(schema, "item_id")) + + metrics = topk_model.evaluate(loader, return_dict=True) + assert all([metric >= 0 for metric in metrics.values()]) + + @pytest.mark.parametrize("run_eagerly", [True, False]) @pytest.mark.parametrize("logits_pop_logq_correction", [True, False]) @pytest.mark.parametrize("loss", ["categorical_crossentropy", "bpr-max", "binary_crossentropy"]) @@ -884,7 +947,7 @@ def test_youtube_dnn_retrieval_v2(sequence_testing_data: Dataset, run_eagerly, t assert losses is not None -def test_two_tower_v2_export_embeddings( +def test_two_tower_v2_export_item_tower_embeddings( ecommerce_data: Dataset, ): user_schema = ecommerce_data.schema.select_by_tag(Tags.USER_ID) @@ -907,7 +970,38 @@ def test_two_tower_v2_export_embeddings( _check_embeddings(candidates, 100, 8, "item_id") -def test_mf_v2_export_embeddings( +def test_two_tower_v2_export_item_tower_embeddings_with_seq_item_features( + music_streaming_data: Dataset, +): + # Changing the schema of the multi-hot "item_genres" feature to be + # dense (not ragged) + music_streaming_data.schema["item_genres"] = music_streaming_data.schema[ + "item_genres" + ].with_shape(((0, None), (4, 4))) + schema = music_streaming_data.schema + user_schema = schema.select_by_tag(Tags.USER) + candidate_schema = schema.select_by_tag(Tags.ITEM) + + query = mm.Encoder(user_schema, mm.MLPBlock([8])) + candidate = mm.Encoder(candidate_schema, mm.MLPBlock([8])) + model = mm.TwoTowerModelV2( + query_tower=query, candidate_tower=candidate, negative_samplers=["in-batch"] + ) + + model, _ = testing_utils.model_test(model, music_streaming_data, reload_model=False) + + queries = model.query_embeddings( + music_streaming_data, batch_size=16, index=Tags.USER_ID + ).compute() + _check_embeddings(queries, 100, 8, "user_id") + + candidates = model.candidate_embeddings( + music_streaming_data, batch_size=16, index=Tags.ITEM_ID + ).compute() + _check_embeddings(candidates, 100, 8, "item_id") + + +def test_mf_v2_export_item_tower_embeddings( ecommerce_data: Dataset, ): model = mm.MatrixFactorizationModelV2( @@ -939,7 +1033,7 @@ def _check_embeddings(embeddings, extected_len, num_dim=8, index_name=None): assert embeddings.index.name == index_name -def test_youtube_dnn_v2_export_embeddings(sequence_testing_data: Dataset): +def test_youtube_dnn_v2_export_item_embeddings(sequence_testing_data: Dataset): to_remove = ["event_timestamp"] + ( sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE) .select_by_tag(Tags.CONTINUOUS) diff --git a/tests/unit/torch/blocks/test_attention.py b/tests/unit/torch/blocks/test_attention.py new file mode 100644 index 0000000000..0c8926e285 --- /dev/null +++ b/tests/unit/torch/blocks/test_attention.py @@ -0,0 +1,50 @@ +import pytest +import torch +from torch import nn + +from merlin.models.torch.blocks.attention import CrossAttentionBlock +from merlin.models.torch.utils import module_utils + + +class TestCrossAttentionBlock: + def 
setup_method(self): + # Set up a simple CrossAttentionBlock instance for testing. + self.cross = CrossAttentionBlock( + nn.TransformerEncoderLayer(10, 2, dim_feedforward=10, batch_first=True, dropout=0.0), + attention=nn.MultiheadAttention(10, 2, batch_first=True), + key="context", + seq_key="sequence", + ) + self.input_dict = {"context": torch.randn(1, 2, 10), "sequence": torch.randn(1, 6, 10)} + + def test_init(self): + assert self.cross.key == "context" + assert self.cross.seq_key == "sequence" + assert isinstance(self.cross.cross_attention, nn.ModuleList) + assert isinstance(self.cross.cross_attention[0], nn.MultiheadAttention) + + def test_forward(self): + out = self.cross(self.input_dict) + assert isinstance(out, torch.Tensor) + assert out.shape == self.input_dict["sequence"].shape + + def test_forward_torch_script(self): + out = module_utils.module_test(self.cross, self.input_dict) + assert isinstance(out, torch.Tensor) + assert out.shape == self.input_dict["sequence"].shape + + def test_get_seq_error(self): + with pytest.raises(RuntimeError, match="Could not find"): + self.cross.get_seq( + {"context": torch.randn(1, 10), "0": torch.randn(1, 10), "1": torch.randn(1, 10)} + ) + + with pytest.raises( + RuntimeError, match="Please set seq_key for when more than 2 keys are present" + ): + cross = CrossAttentionBlock( + attention=nn.MultiheadAttention(10, 2, batch_first=True), + ) + cross.get_seq( + {"context": torch.randn(1, 10), "0": torch.randn(1, 10), "1": torch.randn(1, 10)} + ) diff --git a/tests/unit/torch/blocks/test_dlrm.py b/tests/unit/torch/blocks/test_dlrm.py index 21d65e8561..4b1a2fc82c 100644 --- a/tests/unit/torch/blocks/test_dlrm.py +++ b/tests/unit/torch/blocks/test_dlrm.py @@ -84,7 +84,7 @@ def test_dlrm_block_no_categorical_features(self): schema = self.schema.remove_by_tag(Tags.CATEGORICAL) embedding_dim = 32 - with pytest.raises(ValueError, match="must have a categorical input"): + with pytest.raises(ValueError, match="not found in"): _ = mm.DLRMBlock( schema, dim=embedding_dim, diff --git a/tests/unit/torch/blocks/test_experts.py b/tests/unit/torch/blocks/test_experts.py new file mode 100644 index 0000000000..5f32bf3485 --- /dev/null +++ b/tests/unit/torch/blocks/test_experts.py @@ -0,0 +1,123 @@ +import pytest +import torch + +import merlin.models.torch as mm +from merlin.models.torch.blocks.experts import ( + CGCBlock, + ExpertGateBlock, + MMOEBlock, + PLEBlock, + PLEExpertGateBlock, +) +from merlin.models.torch.utils import module_utils + +dict_inputs = {"experts": torch.rand((10, 4, 5)), "shortcut": torch.rand((10, 8))} + + +class TestExpertGateBlock: + @pytest.fixture + def expert_gate(self): + return ExpertGateBlock(num_experts=4) + + def test_requires_dict_input(self, expert_gate): + with pytest.raises(RuntimeError, match="ExpertGateBlock requires a dictionary input"): + expert_gate(torch.rand((10, 5))) + + def test_forward_pass(self, expert_gate): + result = module_utils.module_test(expert_gate, dict_inputs) + assert result.shape == (10, 5) + + +class TestMMOEBlock: + def test_init(self): + mmoe = MMOEBlock(mm.MLPBlock([2, 2]), 2) + + assert isinstance(mmoe, MMOEBlock) + assert isinstance(mmoe[0], mm.ShortcutBlock) + assert len(mmoe[0][0].branches) == 2 + for i in range(2): + assert mmoe[0][0][str(i)][1].out_features == 2 + assert mmoe[0][0][str(i)][3].out_features == 2 + assert isinstance(mmoe[0][0].post[0], mm.Stack) + assert isinstance(mmoe[1], ExpertGateBlock) + assert mmoe[1][0][0].out_features == 2 + + def test_init_with_outputs(self): + outputs = 
mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) + outputs.prepend_for_each(mm.MLPBlock([2])) + outputs.prepend(MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs)) + + assert isinstance(outputs.pre[0], MMOEBlock) + assert list(outputs.pre[0][1].keys()) == ["a", "b"] + + def test_forward(self): + mmoe = MMOEBlock(mm.MLPBlock([2, 2]), 2) + + outputs = module_utils.module_test(mmoe, torch.rand(5, 5)) + assert outputs.shape == (5, 2) + + def test_forward_with_outputs(self): + outputs = mm.ParallelBlock({"a": mm.BinaryOutput(), "b": mm.BinaryOutput()}) + outputs.prepend_for_each(mm.MLPBlock([2, 2])) + outputs.prepend(MMOEBlock(mm.MLPBlock([2, 2]), 2, outputs)) + + outputs = module_utils.module_test(outputs, torch.rand(5, 5)) + assert outputs["a"].shape == (5, 1) + assert outputs["b"].shape == (5, 1) + + +class TestPLEExpertGateBlock: + @pytest.fixture + def ple_expert_gate(self): + return PLEExpertGateBlock( + num_experts=6, task_experts=mm.repeat_parallel(mm.MLPBlock([5, 5]), 2), name="a" + ) + + def test_repr(self, ple_expert_gate): + assert "(task_experts)" in str(ple_expert_gate) + assert "(gate)" in str(ple_expert_gate) + + def test_requires_dict_input(self, ple_expert_gate): + with pytest.raises(RuntimeError, match="ExpertGateBlock requires a dictionary input"): + ple_expert_gate(torch.rand((10, 5))) + + def test_ple_forward(self, ple_expert_gate): + result = module_utils.module_test(ple_expert_gate, dict_inputs) + assert result.shape == (10, 5) + + +class TestCGCBlock: + @pytest.mark.parametrize("shared_gate", [True, False]) + def test_forward(self, music_streaming_data, shared_gate): + output_block = mm.TabularOutputBlock(music_streaming_data.schema, init="defaults") + cgc = CGCBlock( + mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + outputs=output_block, + shared_gate=shared_gate, + ) + + outputs = module_utils.module_test(cgc, torch.rand(5, 5)) + assert len(outputs) == len(output_block) + (2 if shared_gate else 0) + + +class TestPLEBlock: + def test_forward(self, music_streaming_data): + output_block = mm.TabularOutputBlock(music_streaming_data.schema, init="defaults") + ple = PLEBlock( + mm.MLPBlock([5]), + num_shared_experts=2, + num_task_experts=2, + depth=2, + outputs=output_block, + ) + + assert isinstance(ple[0], CGCBlock) + assert len(ple[0][1]) == len(output_block) + 1 + assert isinstance(ple[0][1]["experts"][0], ExpertGateBlock) + assert isinstance(ple[1], CGCBlock) + assert list(ple[1][1].branches.keys()) == list(ple[1][1].branches.keys()) + + outputs = module_utils.module_test(ple, torch.rand(5, 5)) + assert len(outputs) == len(output_block) diff --git a/tests/unit/torch/inputs/test_embedding.py b/tests/unit/torch/inputs/test_embedding.py index b9b0014845..9ccc320057 100644 --- a/tests/unit/torch/inputs/test_embedding.py +++ b/tests/unit/torch/inputs/test_embedding.py @@ -28,7 +28,7 @@ def test_init(self, item_id_col_schema, user_id_col_schema): def test_init_defaults(self, item_id_col_schema): table = EmbeddingTable() - table.setup_schema(item_id_col_schema) + table.initialize_from_schema(item_id_col_schema) assert table.dim == 8 @@ -211,7 +211,7 @@ def test_exceptions(self, item_id_col_schema): with pytest.raises(RuntimeError): table(torch.tensor([0, 1, 2])) - table.setup_schema(item_id_col_schema) + table.initialize_from_schema(item_id_col_schema) with pytest.raises(ValueError): table("a") diff --git a/tests/unit/torch/inputs/test_select.py b/tests/unit/torch/inputs/test_select.py index 37c15c784f..ea21e94f34 100644 --- 
a/tests/unit/torch/inputs/test_select.py +++ b/tests/unit/torch/inputs/test_select.py @@ -52,10 +52,10 @@ def test_select(self): assert select_user.select(ColumnSchema("user_id")).schema == user_id assert select_user.select(Tags.USER).schema == self.user_schema - def test_setup_schema(self): + def test_initialize_from_schema(self): select_user = mm.SelectKeys() - select_user.setup_schema(self.user_schema["user_id"]) - assert select_user.schema == Schema([self.user_schema["user_id"]]) + select_user.initialize_from_schema(self.user_schema[["user_id"]]) + assert select_user.schema == self.user_schema[["user_id"]] class TestSelectFeatures: @@ -72,8 +72,8 @@ def test_forward(self): outputs = mm.schema.trace(block, self.batch.features["session_id"], batch=self.batch) assert len(outputs) == 5 - assert mm.schema.input(block).column_names == ["input"] - assert mm.schema.features(block).column_names == [ + assert mm.input_schema(block).column_names == ["input"] + assert mm.feature_schema(block).column_names == [ "user_id", "country", "user_age", diff --git a/tests/unit/torch/inputs/test_tabular.py b/tests/unit/torch/inputs/test_tabular.py index a3ee2fb0a3..c2eff10475 100644 --- a/tests/unit/torch/inputs/test_tabular.py +++ b/tests/unit/torch/inputs/test_tabular.py @@ -41,6 +41,35 @@ def test_init_detaults(self): infer_embedding_dim(self.schema.select_by_name(name)), ) + def test_init_defaults_broadcast(self, sequence_testing_data): + schema = sequence_testing_data.schema + input_block = mm.TabularInputBlock(schema, init="broadcast-context") + + batch = mm.Batch.sample_from(sequence_testing_data, batch_size=10) + padded = mm.TabularPadding(schema)(batch) + outputs = module_utils.module_test(input_block, padded) + + con_cols = input_block["context"].pre[0].column_names + assert len(con_cols) == 3 + assert all(c in outputs for c in con_cols) + assert all(outputs[c].shape[1] == 4 for c in con_cols) + + seq_cols = input_block["sequence"].pre[0].column_names + assert len(seq_cols) == 7 + assert all(c in outputs for c in seq_cols) + + def test_stack_context(self, sequence_testing_data): + schema = sequence_testing_data.schema + input_block = mm.TabularInputBlock(schema, init=mm.stack_context(model_dim=4)) + + batch = mm.Batch.sample_from(sequence_testing_data, batch_size=10) + padded = mm.TabularPadding(schema)(batch) + outputs = module_utils.module_test(input_block, padded) + + assert list(outputs.keys()) == ["context", "sequence"] + assert outputs["context"].shape == (10, 3, 4) + assert outputs["sequence"].shape == (10, 4, 29) + def test_init_agg(self): input_block = mm.TabularInputBlock(self.schema, init="defaults", agg="concat") outputs = module_utils.module_test(input_block, self.batch) @@ -68,27 +97,28 @@ def test_extract_route_two_tower(self): "item_recency", "item_genres", } - assert set(mm.schema.input(towers).column_names) == input_cols - assert mm.schema.output(towers).column_names == ["user", "item"] - categorical = towers.select(Tags.CATEGORICAL) outputs = module_utils.module_test(towers, self.batch) + assert set(mm.input_schema(towers).column_names) == input_cols + assert mm.output_schema(towers).column_names == ["user", "item"] + + categorical = towers.select(Tags.CATEGORICAL) assert mm.schema.extract(towers, Tags.CATEGORICAL)[1] == categorical - assert set(mm.schema.input(towers).column_names) == input_cols - assert mm.schema.output(towers).column_names == ["user", "item"] + assert set(mm.input_schema(towers).column_names) == input_cols + assert mm.output_schema(towers).column_names == 
["user", "item"] outputs = towers(self.batch.features) assert outputs["user"].shape == (10, 10) assert outputs["item"].shape == (10, 10) new_inputs, route = mm.schema.extract(towers, Tags.USER) - assert mm.schema.output(new_inputs).column_names == ["user", "item"] + assert mm.output_schema(new_inputs).column_names == ["user", "item"] assert "user" in new_inputs.branches assert new_inputs.branches["user"][0].select_keys.column_names == ["user"] assert "user" in route.branches - assert mm.schema.output(route).select_by_tag(Tags.EMBEDDING).column_names == ["user"] + assert mm.output_schema(route).select_by_tag(Tags.EMBEDDING).column_names == ["user"] def test_extract_route_embeddings(self): input_block = mm.TabularInputBlock(self.schema, init="defaults", agg="concat") @@ -97,7 +127,7 @@ def test_extract_route_embeddings(self): assert outputs.shape == (10, 107) no_embs, emb_route = mm.schema.extract(input_block, Tags.CATEGORICAL) - output_schema = mm.schema.output(emb_route) + output_schema = mm.output_schema(emb_route) assert len(output_schema.select_by_tag(Tags.USER)) == 3 assert len(output_schema.select_by_tag(Tags.ITEM)) == 3 @@ -131,3 +161,21 @@ def test_extract_double_nesting(self): no_user_id, user_id_route = mm.schema.extract(input_block, Tags.USER_ID) assert no_user_id + + def test_nesting(self): + input_block = mm.TabularInputBlock(self.schema) + input_block.add_route( + lambda schema: schema, + mm.TabularInputBlock(init="defaults"), + ) + outputs = module_utils.module_test(input_block, self.batch) + + for name in mm.schema.select(self.schema, Tags.CONTINUOUS).column_names: + assert name in outputs + + for name in mm.schema.select(self.schema, Tags.CATEGORICAL).column_names: + assert name in outputs + assert outputs[name].shape == ( + 10, + infer_embedding_dim(self.schema.select_by_name(name)), + ) diff --git a/tests/unit/torch/models/test_base.py b/tests/unit/torch/models/test_base.py index d51589b8f1..0b86fa8fc9 100644 --- a/tests/unit/torch/models/test_base.py +++ b/tests/unit/torch/models/test_base.py @@ -44,13 +44,17 @@ def test_init_default(self): model = mm.Model(mm.Block(), nn.Linear(10, 10)) assert isinstance(model, mm.Model) assert len(model) == 2 - assert model.optimizer is torch.optim.Adam - assert isinstance(model.configure_optimizers(), torch.optim.Adam) + assert isinstance(model.configure_optimizers()[0], torch.optim.Adam) - def test_init_optimizer(self): - optimizer = torch.optim.SGD - model = mm.Model(mm.Block(), mm.Block(), optimizer=optimizer) - assert model.optimizer is torch.optim.SGD + def test_init_optimizer_and_scheduler(self): + model = mm.Model(mm.MLPBlock([4, 4])) + model.initialize(mm.Batch(torch.rand(2, 2))) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99) + opt, sched = model.configure_optimizers(optimizer, scheduler) + assert opt == [optimizer] + assert sched == [scheduler] def test_pre_and_pre(self): inputs = torch.tensor([[1, 2], [3, 4]]) @@ -94,6 +98,9 @@ def test_initialize_raises_error(self): with pytest.raises(RuntimeError, match="Unexpected input type"): model.initialize(inputs) + with pytest.raises(ValueError): + mm.Model(mm.Block(), pre=mm.Block()) + def test_script(self): model = mm.Model(mm.Block(), mm.Block()) inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) @@ -125,10 +132,10 @@ def test_training_step_values(self): loss = model.training_step((features, targets), 0) (weights, bias) = model.parameters() expected_outputs = 
nn.Sigmoid()(torch.matmul(features["feature"], weights.T) + bias) - expected_loss = nn.BCEWithLogitsLoss()(expected_outputs, targets["target"]) + expected_loss = nn.BCELoss()(expected_outputs, targets["target"]) assert torch.allclose(loss, expected_loss) - def test_training_step_with_dataloader(self): + def test_step_with_dataloader(self): model = mm.Model( mm.Concat(), mm.BinaryOutput(ColumnSchema("target")), @@ -144,8 +151,11 @@ def test_training_step_with_dataloader(self): loss = model.training_step(batch, 0) assert loss > 0.0 + assert torch.equal( + model.validation_step(batch, 0)["loss"], model.test_step(batch, 0)["loss"] + ) - def test_training_step_with_batch(self): + def test_step_with_batch(self): model = mm.Model( mm.Concat(), mm.BinaryOutput(ColumnSchema("target")), @@ -156,6 +166,9 @@ def test_training_step_with_batch(self): model.initialize(batch) loss = model.training_step(batch, 0) assert loss > 0.0 + assert torch.equal( + model.validation_step(batch, 0)["loss"], model.test_step(batch, 0)["loss"] + ) def test_training_step_missing_output(self): model = mm.Model(mm.Block()) @@ -191,7 +204,7 @@ def test_output_schema(self): "b": torch.tensor([[5.0, 6.0], [7.0, 8.0]]), } outputs = mm.schema.trace(model, inputs) - schema = mm.schema.output(model) + schema = mm.output_schema(model) for name in outputs: assert name in schema.column_names assert schema[name].dtype.name == str(outputs[name].dtype).split(".")[-1] @@ -199,7 +212,7 @@ def test_output_schema(self): def test_no_output_schema(self): model = mm.Model(PlusOne()) with pytest.raises(ValueError, match="Could not get output schema of PlusOne()"): - mm.schema.output(model) + mm.output_schema(model) def test_train_classification_with_lightning_trainer(self, music_streaming_data, batch_size=16): schema = music_streaming_data.schema.select_by_name( @@ -214,10 +227,7 @@ def test_train_classification_with_lightning_trainer(self, music_streaming_data, ) trainer = pl.Trainer(max_epochs=1, devices=1) - - with Loader(music_streaming_data, batch_size=batch_size) as loader: - model.initialize(loader) - trainer.fit(model, loader) + trainer.fit(model, Loader(music_streaming_data, batch_size=batch_size)) assert trainer.logged_metrics["train_loss"] > 0.0 assert trainer.num_training_batches == 7 # 100 rows // 16 per batch + 1 for last batch @@ -226,13 +236,38 @@ def test_train_classification_with_lightning_trainer(self, music_streaming_data, _ = module_utils.module_test(model, batch) +class TestMultiLoader: + def test_train_dataset(self, music_streaming_data): + multi_loader = mm.MultiLoader(music_streaming_data) + assert multi_loader.train_dataloader() is multi_loader.loader_train + + def test_train_loader(self, music_streaming_data): + multi_loader = mm.MultiLoader(Loader(music_streaming_data, 2)) + assert multi_loader.train_dataloader() is multi_loader.loader_train + + def test_valid_dataloader(self, music_streaming_data): + multi_loader = mm.MultiLoader(music_streaming_data, music_streaming_data) + assert multi_loader.val_dataloader() is multi_loader.loader_valid + + def test_test_dataloader(self, music_streaming_data): + multi_loader = mm.MultiLoader(*([music_streaming_data] * 3)) + assert multi_loader.test_dataloader() is multi_loader.loader_test + + def test_teardown(self, music_streaming_data): + multi_loader = mm.MultiLoader(*([music_streaming_data] * 3)) + multi_loader.teardown(None) + assert not hasattr(multi_loader, "loader_train") + assert not hasattr(multi_loader, "loader_valid") + assert not hasattr(multi_loader, "loader_test") + 
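The new `TestMultiLoader` cases above pin down a Lightning-DataModule-style surface for `mm.MultiLoader`: it wraps up to three datasets (or pre-built `Loader` instances) in train/valid/test order, exposes `train_dataloader`/`val_dataloader`/`test_dataloader`, and drops the loaders on `teardown`. A minimal usage sketch consistent with those assertions; the `fit_and_test` helper and the batch size of 16 are illustrative, not part of this diff:

```python
import pytorch_lightning as pl

import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.io import Dataset


def fit_and_test(model: mm.Model, train: Dataset, valid: Dataset, test: Dataset):
    # MultiLoader accepts Datasets or pre-built Loaders, in train/valid/test order.
    data = mm.MultiLoader(Loader(train, 16), Loader(valid, 16), Loader(test, 16))

    trainer = pl.Trainer(max_epochs=1, devices=1)
    trainer.fit(model, data)   # pulls data.train_dataloader() and data.val_dataloader()
    trainer.test(model, data)  # pulls data.test_dataloader()
    # Lightning then calls data.teardown(stage), which deletes the loaders.
```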
+ class TestComputeLoss: def test_tensor_inputs(self): - predictions = torch.randn(2, 1) + predictions = torch.sigmoid(torch.randn(2, 1)) targets = torch.randint(2, (2, 1), dtype=torch.float32) model_outputs = [mm.BinaryOutput(ColumnSchema("a"))] results = compute_loss(predictions, targets, model_outputs) - expected_loss = nn.BCEWithLogitsLoss()(predictions, targets) + expected_loss = nn.BCELoss()(predictions, targets) expected_auroc = AUROC(task="binary")(predictions, targets) expected_acc = Accuracy(task="binary")(predictions, targets) expected_prec = Precision(task="binary")(predictions, targets) @@ -253,48 +288,49 @@ def test_tensor_inputs(self): assert torch.allclose(results["binary_recall"], expected_rec) def test_no_metrics(self): - predictions = torch.randn(2, 1) + predictions = torch.sigmoid(torch.randn(2, 1)) targets = torch.randint(2, (2, 1), dtype=torch.float32) model_outputs = [mm.BinaryOutput(ColumnSchema("a"))] results = compute_loss(predictions, targets, model_outputs, compute_metrics=False) assert sorted(results.keys()) == ["loss"] def test_dict_inputs(self): - predictions = {"a": torch.randn(2, 1)} + outputs = mm.ParallelBlock({"a": mm.BinaryOutput(ColumnSchema("a"))}) + predictions = outputs(torch.randn(2, 1)) targets = {"a": torch.randint(2, (2, 1), dtype=torch.float32)} - model_outputs = (mm.BinaryOutput(ColumnSchema("a")),) - results = compute_loss(predictions, targets, model_outputs) - expected_loss = nn.BCEWithLogitsLoss()(predictions["a"], targets["a"]) + + results = compute_loss(predictions, targets, outputs.find(mm.ModelOutput)) + expected_loss = nn.BCELoss()(predictions["a"], targets["a"]) assert torch.allclose(results["loss"], expected_loss) def test_mixed_inputs(self): predictions = {"a": torch.randn(2, 1)} targets = torch.randint(2, (2, 1), dtype=torch.float32) - model_outputs = (mm.BinaryOutput(ColumnSchema("a")),) + model_outputs = (mm.RegressionOutput(ColumnSchema("a")),) results = compute_loss(predictions, targets, model_outputs) - expected_loss = nn.BCEWithLogitsLoss()(predictions["a"], targets) + expected_loss = nn.MSELoss()(predictions["a"], targets) assert torch.allclose(results["loss"], expected_loss) def test_single_model_output(self): predictions = {"foo": torch.randn(2, 1)} targets = {"foo": torch.randint(2, (2, 1), dtype=torch.float32)} - model_outputs = [mm.BinaryOutput(ColumnSchema("foo"))] + model_outputs = [mm.RegressionOutput(ColumnSchema("foo"))] results = compute_loss(predictions, targets, model_outputs) - expected_loss = nn.BCEWithLogitsLoss()(predictions["foo"], targets["foo"]) + expected_loss = nn.MSELoss()(predictions["foo"], targets["foo"]) assert torch.allclose(results["loss"], expected_loss) def test_tensor_input_no_targets(self): predictions = torch.randn(2, 1) - binary_output = mm.BinaryOutput(ColumnSchema("foo")) + binary_output = mm.RegressionOutput(ColumnSchema("foo")) results = compute_loss(predictions, None, (binary_output,)) - expected_loss = nn.BCEWithLogitsLoss()(predictions, torch.zeros(2, 1)) + expected_loss = nn.MSELoss()(predictions, torch.zeros(2, 1)) assert torch.allclose(results["loss"], expected_loss) def test_dict_input_no_targets(self): predictions = {"foo": torch.randn(2, 1)} - binary_output = mm.BinaryOutput(ColumnSchema("foo")) + binary_output = mm.RegressionOutput(ColumnSchema("foo")) results = compute_loss(predictions, None, (binary_output,)) - expected_loss = nn.BCEWithLogitsLoss()(predictions["foo"], torch.zeros(2, 1)) + expected_loss = nn.MSELoss()(predictions["foo"], torch.zeros(2, 1)) assert 
torch.allclose(results["loss"], expected_loss) def test_no_target_raises_error(self): diff --git a/tests/unit/torch/models/test_ranking.py b/tests/unit/torch/models/test_ranking.py index 0fb463e0ef..1a9e5678e2 100644 --- a/tests/unit/torch/models/test_ranking.py +++ b/tests/unit/torch/models/test_ranking.py @@ -36,3 +36,34 @@ def test_train_dlrm_with_lightning_loader( batch = sample_batch(music_streaming_data, batch_size) _ = module_utils.module_test(model, batch) + + +class TestDCNModel: + @pytest.mark.parametrize("depth", [1, 2]) + @pytest.mark.parametrize("stacked", [True, False]) + @pytest.mark.parametrize("deep_block", [None, mm.MLPBlock([4, 2])]) + def test_train_dcn_with_lightning_trainer( + self, + music_streaming_data, + depth, + stacked, + deep_block, + batch_size=16, + ): + schema = music_streaming_data.schema.select_by_name( + ["item_id", "user_id", "user_age", "item_genres", "click"] + ) + music_streaming_data.schema = schema + + model = mm.DCNModel(schema, depth=depth, deep_block=deep_block, stacked=stacked) + + trainer = pl.Trainer(max_epochs=1, devices=1) + + with Loader(music_streaming_data, batch_size=batch_size) as train_loader: + model.initialize(train_loader) + trainer.fit(model, train_loader) + + assert trainer.logged_metrics["train_loss"] > 0.0 + + batch = sample_batch(music_streaming_data, batch_size) + _ = module_utils.module_test(model, batch) diff --git a/tests/unit/torch/outputs/sampling/__init__.py b/tests/unit/torch/outputs/sampling/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/torch/outputs/sampling/test_in_batch.py b/tests/unit/torch/outputs/sampling/test_in_batch.py new file mode 100644 index 0000000000..d5d5065935 --- /dev/null +++ b/tests/unit/torch/outputs/sampling/test_in_batch.py @@ -0,0 +1,23 @@ +import torch + +from merlin.models.torch.outputs.sampling import InBatchNegativeSampler +from merlin.models.torch.utils import module_utils + + +class TestInBatchNegativeSampler: + def test_forward(self): + # Create a sample input tensor + positive = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + positive_id = torch.tensor([0, 1]) + + # Instantiate the InBatchNegativeSampler + sampler = InBatchNegativeSampler() + + # Call the forward method + output_positive, output_positive_id = module_utils.module_test( + sampler, positive, positive_id=positive_id + ) + + # Check if the output_positive and output_positive_id tensors are correct + assert torch.equal(output_positive, positive) + assert torch.equal(output_positive_id, positive_id) diff --git a/tests/unit/torch/outputs/sampling/test_popularity.py b/tests/unit/torch/outputs/sampling/test_popularity.py new file mode 100644 index 0000000000..8f6b35a033 --- /dev/null +++ b/tests/unit/torch/outputs/sampling/test_popularity.py @@ -0,0 +1,109 @@ +import pytest +import torch + +from merlin.models.torch.outputs.sampling.popularity import ( + LogUniformSampler, + PopularityBasedSampler, +) + + +class TestLogUniformSampler: + def test_init(self): + range_max = 1000 + n_sample = 100 + sampler = LogUniformSampler(n_sample, range_max) + + assert sampler.max_id == range_max + assert sampler.max_n_samples == n_sample + assert sampler.dist.size(0) == range_max + assert sampler.unique_sampling_dist.size(0) == range_max + + def test_sample(self): + range_max = 1000 + n_sample = 100 + sampler = LogUniformSampler(n_sample, range_max) + + labels = torch.tensor([10, 50, 150]) + neg_samples, true_log_probs, samp_log_probs = sampler.sample(labels) + + assert true_log_probs.size() == labels.size() + 
assert samp_log_probs.size()[0] <= 2 * n_sample + assert neg_samples.size()[0] <= 2 * n_sample + + @pytest.mark.parametrize("range_max, n_sample", [(1000, 100), (5000, 250), (10000, 500)]) + def test_dist_sum(self, range_max, n_sample): + sampler = LogUniformSampler(n_sample, range_max) + + assert torch.isclose(sampler.dist.sum(), torch.tensor(1.0), atol=1e-6) + + def test_init_exceptions(self): + with pytest.raises(ValueError, match="n_sample must be a positive integer."): + LogUniformSampler(-100, 1000) + + with pytest.raises(ValueError, match="max_id must be a positive integer"): + LogUniformSampler(100, -1000) + + def test_sample_exceptions(self): + range_max = 1000 + n_sample = 100 + sampler = LogUniformSampler(n_sample, range_max) + + with pytest.raises(TypeError, match="Labels must be a torch.Tensor."): + sampler.sample([10, 50, 150]) + + with pytest.raises(ValueError, match="Labels must be a tensor of dtype long."): + sampler.sample(torch.tensor([10, 50, 150], dtype=torch.float32)) + + with pytest.raises(ValueError, match="Labels must be a tensor of dtype long."): + sampler.sample(torch.tensor([])) + + with pytest.raises(ValueError, match="All label values must be within the range"): + sampler.sample(torch.tensor([-1, 50, 150])) + + with pytest.raises(ValueError, match="All label values must be within the range"): + sampler.sample(torch.tensor([10, 50, 150, 2000])) + + +class TestPopularityBasedSampler: + def test_init_defaults(self): + sampler = PopularityBasedSampler() + assert sampler.labels.dtype == torch.int64 + assert sampler.labels.shape == torch.Size([1, 1]) + assert sampler.max_num_samples == 10 + + def test_init_custom_values(self): + sampler = PopularityBasedSampler(max_num_samples=20) + assert sampler.max_num_samples == 20 + + def test_forward_raises_runtime_error(self): + sampler = PopularityBasedSampler() + with pytest.raises(RuntimeError): + sampler(torch.tensor([1.0]), torch.tensor([1])) + + def test_forward(self): + class MockToCall: + def embedding_lookup(self, ids): + return torch.tensor([42.0]) + + num_classes = 1000 + + sampler = PopularityBasedSampler() + sampler.set_to_call(MockToCall()) + + negative, negative_id = sampler(torch.tensor([1.0]), torch.tensor([1])) + + assert isinstance(negative, torch.Tensor) + assert isinstance(negative_id, torch.Tensor) + + def test_sampler_property(self): + class MockToCall: + num_classes = 1000 + + sampler = PopularityBasedSampler() + sampler.set_to_call(MockToCall()) + + log_uniform_sampler = sampler.sampler + + assert isinstance(log_uniform_sampler, LogUniformSampler) + assert log_uniform_sampler.max_id == MockToCall.num_classes + assert log_uniform_sampler.max_n_samples == sampler.max_num_samples diff --git a/tests/unit/torch/outputs/test_base.py b/tests/unit/torch/outputs/test_base.py index a156a39a13..9f38c52033 100644 --- a/tests/unit/torch/outputs/test_base.py +++ b/tests/unit/torch/outputs/test_base.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import numpy as np +import pytest import torch from torch import nn from torchmetrics import AUROC, Accuracy import merlin.models.torch as mm from merlin.models.torch.utils import module_utils -from merlin.schema import ColumnSchema, Schema, Tags class TestModelOutput: @@ -31,8 +30,8 @@ def test_init(self): assert isinstance(model_output, mm.ModelOutput) assert model_output.loss is loss - assert model_output.metrics == () - assert model_output.output_schema == Schema() + assert model_output.metrics is None + assert not mm.schema.output_schema(model_output) def test_identity(self): block = mm.Block() @@ -47,20 +46,11 @@ def test_identity(self): def test_setup_metrics(self): block = mm.Block() loss = nn.BCEWithLogitsLoss() - metrics = (Accuracy(task="binary"), AUROC(task="binary")) + metrics = [Accuracy(task="binary"), AUROC(task="binary")] model_output = mm.ModelOutput(block, loss=loss, metrics=metrics) assert model_output.metrics == metrics - def test_setup_schema(self): - block = mm.Block() - loss = nn.BCEWithLogitsLoss() - schema = ColumnSchema("feature", dtype=np.int32, tags=[Tags.CONTINUOUS]) - model_output = mm.ModelOutput(block, loss=loss, schema=schema) - - assert isinstance(model_output.output_schema, Schema) - assert model_output.output_schema.first == schema - def test_eval_resets_target(self): block = mm.Block() loss = nn.BCEWithLogitsLoss() @@ -71,3 +61,26 @@ def test_eval_resets_target(self): assert torch.equal(model_output.target, torch.ones(1)) model_output.eval() assert torch.equal(model_output.target, torch.zeros(1)) + + def test_copy(self): + block = mm.Block() + loss = nn.BCEWithLogitsLoss() + metrics = [Accuracy(task="multiclass", num_classes=11)] + model_output = mm.ModelOutput(block, loss=loss, metrics=metrics) + + model_copy = model_output.copy() + assert model_copy.loss is not loss + assert isinstance(model_copy.loss, nn.BCEWithLogitsLoss) + assert model_copy.metrics[0] is not metrics[0] + assert model_copy.metrics[0].__class__.__name__ == "MulticlassAccuracy" + assert model_copy.metrics[0].num_classes == 11 + + @pytest.mark.parametrize("logits_temperature", [0.1, 0.9]) + def test_logits_temperature_scaler(self, logits_temperature): + block = mm.Block() + model_output = mm.ModelOutput(block, logits_temperature=logits_temperature) + inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + + outputs = module_utils.module_test(model_output, inputs) + + assert torch.allclose(inputs / logits_temperature, outputs) diff --git a/tests/unit/torch/outputs/test_classification.py b/tests/unit/torch/outputs/test_classification.py index 755d465350..e402ab2c4a 100644 --- a/tests/unit/torch/outputs/test_classification.py +++ b/tests/unit/torch/outputs/test_classification.py @@ -15,14 +15,16 @@ # import pytest import torch +import torchmetrics as tm from torch import nn from torchmetrics import AUROC, Accuracy, Precision, Recall from torchmetrics.classification import BinaryF1Score import merlin.dtypes as md import merlin.models.torch as mm +from merlin.models.torch.outputs.classification import CategoricalTarget, EmbeddingTablePrediction from merlin.models.torch.utils import module_utils -from merlin.schema import ColumnSchema, Schema +from merlin.schema import ColumnSchema, Schema, Tags class TestBinaryOutput: @@ -30,14 +32,15 @@ def test_init(self): binary_output = mm.BinaryOutput() assert isinstance(binary_output, mm.BinaryOutput) - assert isinstance(binary_output.loss, nn.BCEWithLogitsLoss) + assert isinstance(binary_output.loss, nn.BCELoss) assert binary_output.metrics == [ 
Accuracy(task="binary"), AUROC(task="binary"), Precision(task="binary"), Recall(task="binary"), ] - assert binary_output.output_schema == Schema() + with pytest.raises(ValueError): + mm.output_schema(binary_output) def test_identity(self): binary_output = mm.BinaryOutput() @@ -47,7 +50,7 @@ def test_identity(self): assert outputs.shape == (3, 1) - def test_setup_schema(self): + def test_output_schema(self): schema = ColumnSchema("foo") binary_output = mm.BinaryOutput(schema=schema) @@ -57,8 +60,9 @@ def test_setup_schema(self): assert binary_output.output_schema.first.properties["domain"]["min"] == 0 assert binary_output.output_schema.first.properties["domain"]["max"] == 1 - with pytest.raises(ValueError): - binary_output.setup_schema(Schema(["a", "b"])) + def test_error_multiple_columns(self): + with pytest.raises(ValueError, match="Schema must contain exactly one column"): + mm.BinaryOutput(schema=Schema(["a", "b"])) def test_custom_loss(self): binary_output = mm.BinaryOutput(loss=nn.BCELoss()) @@ -83,3 +87,193 @@ def test_cutom_metrics(self): binary_output.metrics[0](outputs, targets), BinaryF1Score()(outputs, targets), ) + + +class TestCategoricalOutput: + def test_init(self): + int_domain_max = 3 + schema = ( + ColumnSchema("foo") + .with_dtype(md.int32) + .with_properties({"domain": {"name": "bar", "min": 0, "max": int_domain_max}}) + ) + categorical_output = mm.CategoricalOutput(schema) + + assert isinstance(categorical_output, mm.CategoricalOutput) + assert isinstance(categorical_output.loss, nn.CrossEntropyLoss) + assert isinstance(categorical_output.metrics[0], tm.RetrievalHitRate) + assert isinstance(categorical_output.metrics[1], tm.RetrievalNormalizedDCG) + assert isinstance(categorical_output.metrics[2], tm.RetrievalPrecision) + assert isinstance(categorical_output.metrics[3], tm.RetrievalRecall) + + output_schema = categorical_output[0].output_schema.first + assert output_schema.dtype == md.float32 + assert output_schema.properties["domain"]["min"] == 0 + assert output_schema.properties["domain"]["max"] == 1 + assert ( + output_schema.properties["value_count"]["min"] + == output_schema.properties["value_count"]["max"] + == int_domain_max + 1 + ) + assert mm.output_schema(categorical_output) == categorical_output[0].output_schema + + def test_called_with_schema(self): + int_domain_max = 3 + schema = ( + ColumnSchema("foo") + .with_dtype(md.int32) + .with_properties({"domain": {"name": "bar", "min": 0, "max": int_domain_max}}) + ) + categorical_output = mm.CategoricalOutput(schema) + + inputs = torch.randn(3, 2) + outputs = module_utils.module_test(categorical_output, inputs) + + num_classes = int_domain_max + 1 + assert outputs.shape == (3, num_classes) + + def test_weight_tying(self): + embedding_dim = 8 + int_domain_max = 3 + schema = ( + ColumnSchema("foo") + .with_dtype(md.int32) + .with_properties({"domain": {"name": "bar", "min": 0, "max": int_domain_max}}) + ) + table = mm.EmbeddingTable(embedding_dim, schema) + categorical_output = mm.CategoricalOutput.with_weight_tying(table) + + inputs = torch.randn(3, embedding_dim) + outputs = module_utils.module_test(categorical_output, inputs) + + num_classes = int_domain_max + 1 + assert outputs.shape == (3, num_classes) + + cat_output = mm.CategoricalOutput(schema).tie_weights(table) + assert isinstance(cat_output[0], EmbeddingTablePrediction) + + def test_invalid_type_error(self): + with pytest.raises(ValueError, match="Target must be a ColumnSchema or Schema"): + mm.CategoricalOutput("invalid to_call") + + def 
test_multiple_column_schema_error(self, item_id_col_schema, user_id_col_schema): + schema = Schema([item_id_col_schema]) + assert len(schema) == 1 + _ = mm.CategoricalOutput(schema) + + schema_with_two_columns = schema + Schema([user_id_col_schema]) + assert len(schema_with_two_columns) == 2 + with pytest.raises(ValueError, match="must contain exactly one"): + _ = mm.CategoricalOutput(schema_with_two_columns) + + +class TestCategoricalTarget: + def test_init(self, user_id_col_schema): + schema = Schema([user_id_col_schema]) + + # Test with ColumnSchema + model = CategoricalTarget(user_id_col_schema) + assert model.num_classes == user_id_col_schema.int_domain.max + 1 + assert isinstance(model.linear, nn.LazyLinear) + + # Test with Schema + model = CategoricalTarget(feature=schema) + assert model.num_classes == user_id_col_schema.int_domain.max + 1 + assert isinstance(model.linear, nn.LazyLinear) + + def test_forward(self, user_id_col_schema): + model = CategoricalTarget(feature=user_id_col_schema) + + inputs = torch.randn(5, 11) + output = model(inputs) + + assert output.shape == (5, 21) + + def test_forward_with_activation(self, user_id_col_schema): + model = CategoricalTarget(feature=user_id_col_schema, activation=nn.ReLU()) + + inputs = torch.randn(5, 11) + output = model(inputs) + + assert output.shape == (5, 21) + assert torch.all(output >= 0) + + def test_embedding_lookup(self, user_id_col_schema): + model = CategoricalTarget(feature=user_id_col_schema) + + model(torch.randn(5, 11)) # initialize the embedding table + input_indices = torch.tensor([1, 5, 10]) + hidden_vectors = model.embedding_lookup(input_indices) + + assert hidden_vectors.shape == (3, 11) + assert model.embeddings().shape == (11, 21) + + def test_forward_model_output(self): + int_domain_max = 3 + schema = ( + ColumnSchema("foo") + .with_dtype(md.int32) + .with_properties({"domain": {"name": "bar", "min": 0, "max": int_domain_max}}) + ) + target = mm.CategoricalTarget(schema) + categorical_output = mm.ModelOutput(target, loss=nn.CrossEntropyLoss()) + assert mm.output_schema(categorical_output).column_names == ["foo"] + + inputs = torch.randn(3, 2) + outputs = module_utils.module_test(categorical_output, inputs) + num_classes = int_domain_max + 1 + assert outputs.shape == (3, num_classes) + + +class TestEmbeddingTablePrediction: + def test_init_multiple_int_domains(self, user_id_col_schema, item_id_col_schema): + input_block = mm.TabularInputBlock(Schema([user_id_col_schema, item_id_col_schema])) + input_block.add_route(Tags.CATEGORICAL, mm.EmbeddingTable(10)) + table = mm.schema.select(input_block, Tags.USER_ID).leaf() + + with pytest.raises(ValueError): + EmbeddingTablePrediction(table) + + with pytest.raises(ValueError): + EmbeddingTablePrediction.with_weight_tying(input_block) + + with pytest.raises(ValueError): + EmbeddingTablePrediction.with_weight_tying(input_block, "a") + + with pytest.raises(ValueError): + EmbeddingTablePrediction.with_weight_tying(input_block, Tags.CATEGORICAL) + + assert isinstance(EmbeddingTablePrediction(table, Tags.USER_ID), EmbeddingTablePrediction) + assert isinstance( + EmbeddingTablePrediction.with_weight_tying(input_block, Tags.USER_ID), + EmbeddingTablePrediction, + ) + + def test_forward(self, user_id_col_schema): + input_block = mm.TabularInputBlock( + Schema([user_id_col_schema]), init="defaults", agg="concat" + ) + prediction = EmbeddingTablePrediction.with_weight_tying(input_block, Tags.USER_ID) + + inputs = torch.randn(5, 8) + output = module_utils.module_test(prediction, 
inputs) + + assert output.shape == (5, 21) + + def test_embedding_lookup(self): + embedding_dim = 8 + int_domain_max = 3 + schema = ( + ColumnSchema("foo") + .with_dtype(md.int32) + .with_properties({"domain": {"name": "bar", "min": 0, "max": int_domain_max}}) + ) + table = mm.EmbeddingTable(embedding_dim, schema) + model = mm.EmbeddingTablePrediction(table) + + batch_size = 16 + ids = torch.randint(0, int_domain_max, (batch_size,)) + outputs = model.embedding_lookup(ids) + + assert outputs.shape == (batch_size, embedding_dim) + assert model.embeddings().shape == (int_domain_max + 1, embedding_dim) diff --git a/tests/unit/torch/outputs/test_constrastive.py b/tests/unit/torch/outputs/test_constrastive.py new file mode 100644 index 0000000000..c616978f68 --- /dev/null +++ b/tests/unit/torch/outputs/test_constrastive.py @@ -0,0 +1,261 @@ +import pytest +import torch + +import merlin.models.torch as mm +from merlin.models.torch.outputs.classification import CategoricalTarget +from merlin.models.torch.outputs.contrastive import ( + ContrastiveOutput, + DotProduct, + rescore_false_negatives, +) +from merlin.models.torch.utils.module_utils import module_test +from merlin.schema import Schema + + +class TestContrastiveOutput: + def test_initialize_from_schema(self, item_id_col_schema, user_id_col_schema): + contrastive = ContrastiveOutput() + + dot = ContrastiveOutput(schema=Schema([item_id_col_schema, user_id_col_schema])) + assert isinstance(dot.to_call, DotProduct) + + target = ContrastiveOutput(schema=Schema([item_id_col_schema])) + assert isinstance(target.to_call, CategoricalTarget) + + with pytest.raises(ValueError): + contrastive.initialize_from_schema(1) + + with pytest.raises(ValueError): + contrastive.initialize_from_schema(Schema(["a", "b", "c"])) + + def test_outputs_without_downscore(self, item_id_col_schema): + contrastive = ContrastiveOutput(item_id_col_schema, downscore_false_negatives=False) + + query = torch.tensor([[0.1, 0.2], [0.3, 0.4]]) + positive = torch.tensor([[0.5, 0.6], [0.7, 0.8]]) + negative = torch.tensor([[0.9, 1.0], [1.1, 1.2], [1.3, 1.4]]) + + id_tensor = torch.tensor(1) + out = contrastive.contrastive_outputs(query, positive, negative, id_tensor, id_tensor) + + expected_out = torch.tensor( + [[0.1700, 0.2900, 0.3500, 0.4100], [0.5300, 0.6700, 0.8100, 0.9500]] + ) + expected_target = torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]) + + assert torch.allclose(out, expected_out, atol=1e-4) + assert torch.equal(contrastive.target, expected_target) + + def test_outputs_with_rescore_false_negatives(self, item_id_col_schema): + contrastive = ContrastiveOutput(item_id_col_schema, false_negative_score=-100.0) + + query = torch.tensor([[0.1, 0.2]]) + positive = torch.tensor([[0.5, 0.6]]) + negative = torch.tensor([[0.5, 0.6], [0.9, 1.0]]) + positive_id = torch.tensor([[0]]) + negative_id = torch.tensor([[0, 1]]) + + out = contrastive.contrastive_outputs(query, positive, negative, positive_id, negative_id) + + # Explanation: + # 1. positive_scores = dot(query, positive) = dot([0.1, 0.2], [0.5, 0.6]) = 0.17 + # 2. negative_scores = matmul(query, negative.T) + # = matmul([[0.1, 0.2]], [[0.5, 0.9], [0.6, 1.0]]) = [[0.17, 0.29]] + # 3. Since the first negative sample is a false negative (its id is in positive_id), + # we downscore it to -100.0 + # 4. 
The final output is a concatenation of the positive_scores and the + # rescored negative_scores: [[0.17, -100.0, 0.29]] + + expected_out = torch.tensor([[0.17, -100.0, 0.29]]) + expected_target = torch.tensor([[1.0, 0.0, 0.0]]) + + assert torch.allclose(out, expected_out, atol=1e-4) + assert torch.equal(contrastive.target, expected_target) + + def test_outputs_raises_runtime_error(self, item_id_col_schema): + contrastive = ContrastiveOutput(item_id_col_schema) + + query = torch.tensor([[0.1, 0.2], [0.3, 0.4]]) + positive = torch.tensor([[0.5, 0.6], [0.7, 0.8]]) + negative = torch.tensor([[0.9, 1.0], [1.1, 1.2], [1.3, 1.4]]) + positive_id = torch.tensor([0, 1]) + negative_id = torch.tensor([1, 2, 3]) + + with pytest.raises(RuntimeError): + contrastive.contrastive_outputs(query, positive, negative, positive_id) + + with pytest.raises(RuntimeError): + contrastive.contrastive_outputs(query, positive, negative, negative_id=negative_id) + + def test_call_categorical_target(self, item_id_col_schema): + model = mm.Block( + mm.EmbeddingTable(10, Schema([item_id_col_schema])), + mm.Concat(), + ContrastiveOutput(item_id_col_schema), + ) + + item_id = torch.tensor([1, 2, 3]) + features = {"item_id": item_id} + batch = mm.Batch(features) + + outputs = module_test(model, features, batch=batch) + contrastive_outputs = model(features, batch=batch.replace(targets=item_id + 1)) + + assert contrastive_outputs.shape == (3, 4) + assert outputs.shape == (3, 11) + + self.assert_contrastive_outputs(contrastive_outputs, model[-1].target, outputs) + + def test_call_dot_product(self, user_id_col_schema, item_id_col_schema): + schema = Schema([user_id_col_schema, item_id_col_schema]) + model = mm.Block( + mm.EmbeddingTables(10, schema), + ContrastiveOutput(item_id_col_schema), + ) + + user_id = torch.tensor([1, 2, 3]) + item_id = torch.tensor([1, 2, 3]) + + data = {"user_id": user_id, "item_id": item_id} + contrastive_outputs = model(data, batch=mm.Batch(data)) + + model.eval() + outputs = module_test(model, data, batch=mm.Batch(data)) + + self.assert_contrastive_outputs(contrastive_outputs, model[-1].target, outputs) + + @pytest.mark.parametrize("negative_sampling", ["popularity", "in-batch"]) + def test_call_weight_tying(self, user_id_col_schema, item_id_col_schema, negative_sampling): + schema = Schema([user_id_col_schema, item_id_col_schema]) + embeddings = mm.EmbeddingTables(10, schema) + model = mm.Block( + embeddings, + mm.MLPBlock([10]), + ContrastiveOutput.with_weight_tying( + embeddings, item_id_col_schema, negative_sampling=negative_sampling + ), + ) + + user_id = torch.tensor([1, 2, 3]) + item_id = torch.tensor([1, 2, 3]) + + data = {"user_id": user_id, "item_id": item_id} + contrastive_outputs = model(data, batch=mm.Batch(data, targets=item_id + 1)) + + model.eval() + outputs = module_test(model, data, batch=mm.Batch(data)) + + assert outputs.shape == (3, 11) + + if negative_sampling == "popularity": + assert contrastive_outputs.shape[0] == 3 + assert contrastive_outputs.shape[1] >= 4 + else: + assert contrastive_outputs.shape == (3, 4) + self.assert_contrastive_outputs(contrastive_outputs, model[-1].target, outputs) + + def assert_contrastive_outputs(self, contrastive_outputs, targets, outputs): + assert contrastive_outputs.shape == (3, 4) + assert targets.shape == (3, 4) + assert not torch.equal(outputs, contrastive_outputs) + assert targets[:, 0].all().item() + assert torch.equal(targets[:, 1:], torch.zeros_like(targets[:, 1:])) + + +class Test_rescore_false_negatives: + def 
test_no_false_negatives(self): + positive_item_ids = torch.tensor([1, 3, 5]) + neg_samples_item_ids = torch.tensor([2, 4, 6]) + negative_scores = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) + false_negatives_score = -100.0 + + rescored_neg_scores, valid_negatives_mask = rescore_false_negatives( + positive_item_ids, neg_samples_item_ids, negative_scores, false_negatives_score + ) + + assert torch.equal(rescored_neg_scores, negative_scores) + assert torch.equal( + valid_negatives_mask, torch.ones_like(valid_negatives_mask, dtype=torch.bool) + ) + + def test_with_false_negatives(self): + positive_item_ids = torch.tensor([1, 3, 5]) + neg_samples_item_ids = torch.tensor([1, 4, 5]) + negative_scores = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) + false_negatives_score = -100.0 + + rescored_neg_scores, valid_negatives_mask = rescore_false_negatives( + positive_item_ids, neg_samples_item_ids, negative_scores, false_negatives_score + ) + + expected_rescored_neg_scores = torch.tensor( + [[-100.0, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, -100.0]] + ) + expected_valid_negatives_mask = torch.tensor( + [[False, True, True], [True, True, True], [True, True, False]], dtype=torch.bool + ) + + assert torch.equal(rescored_neg_scores, expected_rescored_neg_scores) + assert torch.equal(valid_negatives_mask, expected_valid_negatives_mask) + + def test_all_false_negatives(self): + positive_item_ids = torch.tensor([1, 3, 5]) + neg_samples_item_ids = torch.tensor([1, 3, 5]) + negative_scores = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]) + false_negatives_score = -100.0 + + rescored_neg_scores, valid_negatives_mask = rescore_false_negatives( + positive_item_ids, neg_samples_item_ids, negative_scores, false_negatives_score + ) + + expected_rescored_neg_scores = torch.tensor( + [[-100.0, 0.2, 0.3], [0.4, -100.0, 0.6], [0.7, 0.8, -100.0]] + ) + expected_valid_negatives_mask = torch.tensor( + [[False, True, True], [True, False, True], [True, True, False]], dtype=torch.bool + ) + + assert torch.equal(rescored_neg_scores, expected_rescored_neg_scores) + assert torch.equal(valid_negatives_mask, expected_valid_negatives_mask) + + +class TestDotProduct: + def test_less_than_two_inputs(self): + dp = DotProduct() + with pytest.raises(RuntimeError, match=r"DotProduct requires at least two inputs"): + dp({"candidate": torch.tensor([1, 2, 3])}) + + def test_no_query_name_more_than_two_inputs(self): + dp = DotProduct() + with pytest.raises( + RuntimeError, match=r"DotProduct requires query_name to be set when more than" + ): + dp( + { + "candidate": torch.tensor([1, 2, 3]), + "extra": torch.tensor([4, 5, 6]), + "another_extra": torch.tensor([7, 8, 9]), + } + ) + + def test_valid_dot_product_operation(self): + dp = DotProduct(query_name="query") + result = dp.forward( + {"query": torch.tensor([1.0, 2.0, 3.0]), "candidate": torch.tensor([4.0, 5.0, 6.0])} + ) + assert torch.allclose(result, torch.tensor([32.0]), atol=1e-4) + + def test_no_query_name_3_inputs(self): + dp = DotProduct(query_name=None) + with pytest.raises(RuntimeError): + dp( + { + "query": torch.tensor([1.0, 2.0, 3.0]), + "candidate": torch.tensor([1.0, 2.0, 3.0]), + "extra": torch.tensor([4.0, 5.0, 6.0]), + } + ) + + def test_should_apply_contrastive(self): + dp = DotProduct() + assert dp.should_apply_contrastive(None) == dp.training diff --git a/tests/unit/torch/outputs/test_regression.py b/tests/unit/torch/outputs/test_regression.py index f8537bca51..89a9e955fd 100644 --- 
a/tests/unit/torch/outputs/test_regression.py +++ b/tests/unit/torch/outputs/test_regression.py @@ -31,7 +31,6 @@ def test_init(self): assert isinstance(reg_output, mm.RegressionOutput) assert isinstance(reg_output.loss, nn.MSELoss) assert reg_output.metrics == [MeanSquaredError()] - assert reg_output.output_schema == Schema() def test_identity(self): reg_output = mm.RegressionOutput() @@ -41,7 +40,7 @@ def test_identity(self): assert outputs.shape == (3, 1) - def test_setup_schema(self): + def test_output_schema(self): schema = ColumnSchema("foo", dtype=md.int16) reg_output = mm.RegressionOutput(schema=schema) @@ -50,8 +49,9 @@ def test_setup_schema(self): assert reg_output.output_schema.first.dtype == md.float32 assert Tags.CONTINUOUS in reg_output.output_schema.first.tags - with pytest.raises(ValueError): - reg_output.setup_schema(Schema(["a", "b"])) + def test_error_multiple_columns(self): + with pytest.raises(ValueError, match="Schema must contain exactly one column"): + mm.RegressionOutput(schema=Schema(["a", "b"])) def test_default_loss(self): reg_output = mm.RegressionOutput() @@ -91,7 +91,7 @@ def test_default_metrics(self): def test_custom_metrics(self): reg_output = mm.RegressionOutput( - metrics=(MeanAbsoluteError(), MeanAbsolutePercentageError()) + metrics=[MeanAbsoluteError(), MeanAbsolutePercentageError()] ) features = torch.randn(3, 2) targets = torch.randn(3, 1) diff --git a/tests/unit/torch/outputs/test_tabular.py b/tests/unit/torch/outputs/test_tabular.py index 22ae132735..1c5fd3a453 100644 --- a/tests/unit/torch/outputs/test_tabular.py +++ b/tests/unit/torch/outputs/test_tabular.py @@ -1,9 +1,10 @@ import pytest import torch +import merlin.dtypes as md import merlin.models.torch as mm from merlin.models.torch.utils import module_utils -from merlin.schema import Schema, Tags +from merlin.schema import ColumnSchema, Schema, Tags class TestTabularOutputBlock: @@ -25,12 +26,56 @@ def test_init_defaults(self): assert "click" in outputs assert "like" in outputs + def test_init_defaults_with_binary_categorical(self): + test_schema = Schema( + [ + ColumnSchema("foo") + .with_dtype(md.int32) + .with_properties({"domain": {"name": "bar", "min": 0, "max": 1}}) + .with_tags([Tags.CATEGORICAL, Tags.TARGET]) + ] + ) + output_block = mm.TabularOutputBlock(test_schema, init="defaults") + + assert isinstance(output_block["foo"], mm.BinaryOutput) + + outputs = module_utils.module_test(output_block, torch.rand(10, 10)) + + assert "foo" in outputs + + def test_init_defaults_with_multiclass_categorical(self): + test_schema = Schema( + [ + ColumnSchema("foo") + .with_dtype(md.int32) + .with_properties({"domain": {"name": "bar", "min": 0, "max": 3}}) + .with_tags([Tags.CATEGORICAL, Tags.TARGET]) + ] + ) + output_block = mm.TabularOutputBlock(test_schema, init="defaults") + + assert isinstance(output_block["foo"], mm.CategoricalOutput) + + outputs = module_utils.module_test(output_block, torch.rand(10, 10)) + + assert "foo" in outputs + def test_exceptions(self): with pytest.raises(ValueError, match="not found"): mm.TabularOutputBlock(self.schema, init="not_found") def test_no_route_for_non_existent_tag(self): outputs = mm.TabularOutputBlock(self.schema) - outputs.add_route(Tags.CATEGORICAL) + outputs.add_route(Tags.CATEGORICAL, required=False) assert not outputs + + def test_nesting(self): + output_block = mm.TabularOutputBlock(self.schema) + output_block.add_route(Tags.TARGET, mm.TabularOutputBlock(init="defaults")) + + outputs = module_utils.module_test(output_block, torch.rand(10, 10)) + + 
assert "play_percentage" in outputs + assert "click" in outputs + assert "like" in outputs diff --git a/tests/unit/torch/test_batch.py b/tests/unit/torch/test_batch.py index 58e8c96ffb..2cab4f5043 100644 --- a/tests/unit/torch/test_batch.py +++ b/tests/unit/torch/test_batch.py @@ -186,6 +186,69 @@ def test_device(self): with pytest.raises(ValueError, match="Batch is empty"): empty_batch.device() + def test_flatten_as_dict(self): + features = {"feature1": torch.tensor([1, 2]), "feature2": torch.tensor([3, 4])} + targets = {"target1": torch.tensor([5, 6])} + lengths = {"length1": torch.tensor([7, 8])} + masks = {"mask1": torch.tensor([9, 10])} + sequences = Sequence(lengths, masks) + batch = Batch(features, targets, sequences) + + # without inputs + result = batch.flatten_as_dict(batch) + assert len(result) == 5 # 2 features, 1 target, 1 length, 1 mask + assert set(result.keys()) == set( + [ + "features.feature1", + "features.feature2", + "targets.target1", + "lengths.length1", + "masks.mask1", + ] + ) + + # with inputs + input_batch = Batch( + {"feature2": torch.tensor([11, 12])}, {"target1": torch.tensor([13, 14])}, sequences + ) + result = batch.flatten_as_dict(input_batch) + assert len(result) == 9 # input keys are considered + assert ( + len([k for k in result if k.startswith("inputs.")]) == 4 + ) # 1 feature, 1 target, 1 length, 1 mask + + def test_from_partial_dict(self): + features = {"feature1": torch.tensor([1, 2]), "feature2": torch.tensor([3, 4])} + targets = {"target1": torch.tensor([5, 6])} + lengths = {"length1": torch.tensor([7, 8])} + masks = {"mask1": torch.tensor([9, 10])} + sequences = Sequence(lengths, masks) + + batch = Batch(features, targets, sequences) + + partial_dict = { + "features.feature1": torch.tensor([11, 12]), + "targets.target1": torch.tensor([13, 14]), + "lengths.length1": torch.tensor([15, 16]), + "inputs.features.feature1": torch.tensor([1]), + "inputs.targets.target1": torch.tensor([1]), + "inputs.lengths.length1": torch.tensor([1]), + "inputs.masks.mask1": torch.tensor([1]), + } + + # create a batch from partial_dict + result = Batch.from_partial_dict(partial_dict, batch) + + assert result.features["feature1"].equal( + torch.tensor([11, 12]) + ) # updated from partial_dict + assert result.features["feature2"].equal(torch.tensor([3, 4])) # kept from batch + assert result.targets["target1"].equal(torch.tensor([13, 14])) # updated from partial_dict + assert result.sequences.lengths["length1"].equal( + torch.tensor([15, 16]) + ) # updated from partial_dict + assert "mask1" not in result.sequences.masks # removed from batch + class Test_sample_batch: def test_loader(self, music_streaming_data): diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index ea36aaa412..e535cdf1de 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import pytest import torch @@ -57,7 +57,7 @@ def test_identity(self): outputs = module_utils.module_test(block, inputs, batch=Batch(inputs)) assert torch.equal(inputs, outputs) - assert mm.schema.output(block) == mm.schema.output.tensors(inputs) + assert mm.output_schema(block) == mm.output_schema.tensors(inputs) def test_insertion(self): block = Block() @@ -158,7 +158,7 @@ def test_schema_tracking(self): inputs = torch.randn(1, 3) outputs = mm.schema.trace(pb, inputs) - schema = mm.schema.output(pb) + schema = mm.output_schema(pb) for name in outputs: assert name in schema.column_names @@ -258,9 +258,9 @@ def test_set_pre(self): def test_input_schema_pre(self): pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()}) outputs = mm.schema.trace(pb, torch.randn(1, 3)) - input_schema = mm.schema.input(pb) + input_schema = mm.input_schema(pb) assert len(input_schema) == 1 - assert len(mm.schema.output(pb)) == 2 + assert len(mm.output_schema(pb)) == 2 assert len(outputs) == 2 pb2 = ParallelBlock({"a": PlusOne(), "b": PlusOne()}) @@ -270,8 +270,8 @@ def test_input_schema_pre(self): assert get_pre(pb2)[0] == pb pb2.append(pb) - assert input_schema == mm.schema.input(pb2) - assert mm.schema.output(pb2) == mm.schema.output(pb) + assert input_schema == mm.input_schema(pb2) + assert mm.output_schema(pb2) == mm.output_schema(pb) def test_leaf(self): block = ParallelBlock({"a": PlusOne()}) @@ -377,3 +377,60 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch. shortcut_propagate = ShortcutBlock(PlusOneShortcutTuple(), propagate_shortcut=True) with pytest.raises(RuntimeError): module_utils.module_test(shortcut_propagate, torch.rand(5, 5)) + + +class TestBatchBlock: + def test_forward_with_batch(self): + batch = Batch(torch.tensor([1, 2]), torch.tensor([3, 4])) + outputs = mm.BatchBlock()(batch) + + assert batch == outputs + + def test_forward_with_features(self): + feat = torch.tensor([1, 2]) + outputs = module_utils.module_test(mm.BatchBlock(), feat) + assert isinstance(outputs, mm.Batch) + assert torch.equal(outputs.feature(), feat) + + def test_forward_with_tuple(self): + feat, target = torch.tensor([1, 2]), torch.tensor([3, 4]) + outputs = module_utils.module_test(mm.BatchBlock(), feat, targets=target) + + assert isinstance(outputs, mm.Batch) + assert torch.equal(outputs.feature(), feat) + assert torch.equal(outputs.target(), target) + + def test_forward_exception(self): + with pytest.raises( + RuntimeError, match="Features must be a tensor or a dictionary of tensors" + ): + module_utils.module_test(mm.BatchBlock(), (torch.tensor([1, 2]), torch.tensor([1, 2]))) + + def test_nested(self): + feat, target = torch.tensor([1, 2]), torch.tensor([3, 4]) + outputs = module_utils.module_test(mm.BatchBlock(mm.BatchBlock()), feat, targets=target) + + assert isinstance(outputs, mm.Batch) + assert torch.equal(outputs.feature(), feat) + assert torch.equal(outputs.target(), target) + + def test_in_parallel(self): + feat, target = torch.tensor([1, 2]), torch.tensor([3, 4]) + outputs = module_utils.module_test( + mm.BatchBlock(mm.ParallelBlock({"a": mm.BatchBlock()})), feat, targets=target + ) + + assert isinstance(outputs, mm.Batch) + assert torch.equal(outputs.feature(), feat) + assert torch.equal(outputs.target(), target) + + def test_exception(self): + class BatchToTuple(nn.Module): + def forward( + self, inputs: Dict[str, torch.Tensor], batch: Optional[Batch] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + 
return inputs["default"], inputs["default"] + + feat, target = torch.tensor([1, 2]), torch.tensor([3, 4]) + with pytest.raises(RuntimeError, match="Module must return a Batch"): + module_utils.module_test(mm.BatchBlock(BatchToTuple()), feat, targets=target) diff --git a/tests/unit/torch/test_functional.py b/tests/unit/torch/test_functional.py new file mode 100644 index 0000000000..d750808082 --- /dev/null +++ b/tests/unit/torch/test_functional.py @@ -0,0 +1,300 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Iterable, Tuple + +import pytest +import torch.nn as nn + +import merlin.models.torch as mm +from merlin.models.torch.container import BlockContainer +from merlin.models.torch.functional import _create_list_wrapper + + +class CustomMLP(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x): + return x + self.linear2(self.linear1(x)) + + +def add_relu(x): + if isinstance(x, nn.Linear): + return nn.Sequential(x, nn.ReLU()) + return x + + +def add_relu_named(x, name=None, to_replace="linear1"): + if name == to_replace and isinstance(x, nn.Linear): + return nn.Sequential(x, nn.ReLU()) + return x + + +def add_relu_first(x, i=None): + if i == 0 and isinstance(x, nn.Linear): + return nn.Sequential(x, nn.ReLU()) + return x + + +class TestMapModule: + def test_map_identity(self): + # Test mapping an identity function + module = nn.Linear(10, 10) + identity = lambda x: x # noqa: E731 + assert mm.map(module, identity) is module + + def test_map_transform(self): + # Test mapping a transform function + module = nn.Linear(10, 10) + transformed_module = mm.map(module, add_relu) + assert isinstance(transformed_module[0], nn.Linear) + assert isinstance(transformed_module[1], nn.ReLU) + + def test_walk_custom_module(self): + mlp = CustomMLP() + with_relu = mm.walk(mlp, add_relu) + assert isinstance(with_relu.linear1, nn.Sequential) + assert isinstance(with_relu.linear2, nn.Sequential) + + for fn in [add_relu_named, add_relu_first]: + with_relu_first = mm.walk(mlp, fn) + assert isinstance(with_relu_first.linear1, nn.Sequential) + assert isinstance(with_relu_first.linear2, nn.Linear) + + +class TestMapModuleList: + def test_map_identity(self): + # Test mapping an identity function + modules = nn.ModuleList([nn.Linear(10, 10) for _ in range(5)]) + identity = lambda x: x # noqa: E731 + mapped = mm.map(modules, identity) + assert all(m1 == m2 for m1, m2 in zip(modules, mapped)) + + @pytest.mark.parametrize("wrapper", [nn.Sequential, nn.ModuleList]) + def test_map_with_index(self, wrapper): + # Test mapping a function that uses the index + modules = _create_list_wrapper(wrapper(), [nn.Linear(10, 10) for _ in range(5)]) + + def add_index(x, i): + return nn.Linear(10 + i, 10 + i) + + new_modules = mm.map(modules, add_index) + assert isinstance(new_modules, wrapper) + for i, module in enumerate(new_modules): + assert isinstance(module, nn.Linear) + 
assert module.in_features == 10 + i + assert module.out_features == 10 + i + + +class TestMapModuleDict: + def test_map_module_dict(self): + # Define a simple transformation function + def transformation(module: nn.Module, name: str = "", **kwargs) -> nn.Module: + if isinstance(module, nn.Linear): + return nn.Linear(20, 10) + return module + + # Define a ModuleDict of modules + module_dict = nn.ModuleDict({"linear1": nn.Linear(10, 10), "linear2": nn.Linear(10, 10)}) + + # Apply map_module_dict + new_module_dict = mm.map(module_dict, transformation) + + # Assert that the transformation has been applied correctly + for module in new_module_dict.values(): + assert isinstance(module, nn.Linear) + assert module.in_features == 20 + assert module.out_features == 10 + + +class TestContainerMixin: + @pytest.fixture + def container(self) -> BlockContainer: + return BlockContainer(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)) + + def test_filter(self, container): + filtered = container.filter(lambda m: isinstance(m, nn.ReLU)) + assert isinstance(filtered, BlockContainer) + assert len(filtered) == 1 + assert isinstance(filtered[0], nn.ReLU) + + def test_filter_recurse(self, container): + def func(module): + return isinstance(module, nn.Linear) + + filtered = BlockContainer(container).filter(func, recurse=True) + assert isinstance(filtered, BlockContainer) + assert len(filtered) == 1 + assert len(filtered[0]) == 2 + assert isinstance(filtered[0][0], nn.Linear) + assert isinstance(filtered[0][1], nn.Linear) + + def test_flatmap(self, container): + def func(module): + return BlockContainer(*([module] * 2)) + + flat_mapped = container.flatmap(func) + assert isinstance(flat_mapped, BlockContainer) + assert len(flat_mapped) == 6 + + def test_flatmap_non_callable(self, container): + with pytest.raises(TypeError): + container.flatmap(123) + + def test_forall(self, container): + def func(module): + return isinstance(module, nn.Module) + + assert container.forall(func) + + def test_forall_recurse(self, container): + def func(module): + return isinstance(module, nn.ReLU) + + assert not BlockContainer(container).forall(func, recurse=True) + assert BlockContainer(container).forall(lambda x: True, recurse=True) + + def test_map(self, container): + def func(module): + if isinstance(module, nn.Linear): + return nn.Conv2d(3, 3, 3) + return module + + mapped = container.map(func) + assert isinstance(mapped, BlockContainer) + assert len(mapped) == 3 + assert isinstance(mapped[0], nn.Conv2d) + assert isinstance(mapped[1], nn.ReLU) + assert isinstance(mapped[2], nn.Conv2d) + + def test_map_recurse(self, container): + def func(module): + if isinstance(module, nn.Linear): + return nn.Conv2d(3, 3, 3) + return module + + mapped = BlockContainer(container).map(func, recurse=True) + assert isinstance(mapped, BlockContainer) + assert len(mapped) == 1 + assert len(mapped[0]) == 3 + assert isinstance(mapped[0][0], nn.Conv2d) + assert isinstance(mapped[0][1], nn.ReLU) + assert isinstance(mapped[0][2], nn.Conv2d) + + def test_mapi(self, container): + def func(module, idx): + assert idx in [0, 1, 2] + + if isinstance(module, nn.Linear): + return nn.Conv2d(3, 3, 3) + return module + + mapped = container.mapi(func) + assert isinstance(mapped, BlockContainer) + assert len(mapped) == 3 + assert isinstance(mapped[0], nn.Conv2d) + assert isinstance(mapped[1], nn.ReLU) + assert isinstance(mapped[2], nn.Conv2d) + + def test_mapi_recurse(self, container): + def func(module, idx): + assert idx in [0, 1, 2] + if isinstance(module, nn.Linear): 
+ return nn.Conv2d(3, 3, 3) + return module + + mapped = BlockContainer(container).mapi(func, recurse=True) + assert isinstance(mapped, BlockContainer) + assert len(mapped) == 1 + assert len(mapped[0]) == 3 + assert isinstance(mapped[0][0], nn.Conv2d) + assert isinstance(mapped[0][1], nn.ReLU) + assert isinstance(mapped[0][2], nn.Conv2d) + + def test_choose(self, container): + def func(module): + if isinstance(module, nn.Linear): + return nn.Conv2d(3, 3, 3) + + chosen = container.choose(func) + assert isinstance(chosen, BlockContainer) + assert len(chosen) == 2 + assert isinstance(chosen[0], nn.Conv2d) + + def test_choose_recurse(self, container): + def func(module): + if isinstance(module, nn.Linear): + return nn.Conv2d(3, 3, 3) + + chosen = BlockContainer(container).choose(func, recurse=True) + assert isinstance(chosen, BlockContainer) + assert len(chosen) == 1 + assert len(chosen[0]) == 2 + assert isinstance(chosen[0][0], nn.Conv2d) + assert isinstance(chosen[0][1], nn.Conv2d) + + def test_walk(self, container: BlockContainer): + def func(module): + if isinstance(module, nn.Linear): + return nn.Conv2d(3, 3, 3) + return module + + walked = BlockContainer(container).walk(func) + assert isinstance(walked, BlockContainer) + assert len(walked) == 1 + assert len(walked[0]) == 3 + assert isinstance(walked[0][0], nn.Conv2d) + assert isinstance(walked[0][1], nn.ReLU) + assert isinstance(walked[0][2], nn.Conv2d) + + def test_zip(self, container: BlockContainer): + other = BlockContainer(nn.Conv2d(3, 3, 3), nn.ReLU(), nn.Linear(5, 2)) + zipped = lambda: container.zip(other) # noqa: E731 + assert isinstance(zipped(), Iterable) + assert len(list(zipped())) == 3 + assert isinstance(list(zipped())[0], Tuple) + assert isinstance(list(zipped())[0][0], nn.Linear) + assert isinstance(list(zipped())[0][1], nn.Conv2d) + + def test_add(self, container): + new_module = nn.Linear(5, 2) + new_container = container + new_module + assert isinstance(new_container, BlockContainer) + assert len(new_container) == 4 + assert isinstance(new_container[3], nn.Linear) + + _container = container + container + assert len(_container) == 6 + + def test_radd(self, container): + new_module = nn.Linear(5, 2) + new_container = new_module + container + assert isinstance(new_container, BlockContainer) + assert len(new_container) == 4 + assert isinstance(new_container[0], nn.Linear) + + def test_freeze(self, container): + container.freeze() + for param in container.parameters(): + assert not param.requires_grad + + def test_unfreeze(self, container): + container.unfreeze() + for param in container.parameters(): + assert param.requires_grad diff --git a/tests/unit/torch/test_predict.py b/tests/unit/torch/test_predict.py index d559bcd0e9..1a66e1111c 100644 --- a/tests/unit/torch/test_predict.py +++ b/tests/unit/torch/test_predict.py @@ -1,3 +1,5 @@ +from typing import Dict + import pandas as pd import pytest import torch @@ -5,12 +7,12 @@ from merlin.dataloader.torch import Loader from merlin.io import Dataset -from merlin.models.torch.predict import Encoder, Predictor +from merlin.models.torch.predict import DaskEncoder, DaskPredictor, EncoderBlock from merlin.schema import Tags class TensorOutputModel(nn.Module): - def forward(self, x): + def forward(self, x: Dict[str, torch.Tensor]): return x["position"] * 2 @@ -19,25 +21,77 @@ def __init__(self, output_name: str = "testing"): super().__init__() self.name = output_name - def forward(self, x): + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return 
{self.name: x["user_id"] * 2} -class TestEncoder: +class TestEncoderBlock: + def test_encode_loader(self, music_streaming_data): + loader = Loader(music_streaming_data, batch_size=10) + + ddf = music_streaming_data.to_ddf() + num_items = ddf["item_id"].nunique().compute() + + encoder = EncoderBlock(TensorOutputModel()) + outputs = encoder.encode(loader, index=Tags.ITEM_ID).compute() + + assert outputs.index.name == "item_id" + assert len(outputs) == num_items + + def test_encode_dataset(self, music_streaming_data): + encoder = EncoderBlock(DictOutputModel()) + + output = encoder.encode( + music_streaming_data, + selection=Tags.USER, + batch_size=10, + index=Tags.USER_ID, + unique=False, + ) + output_df = output.compute() + assert len(output_df) == 100 + assert set(output.schema.column_names) == {"testing"} + assert output_df.index.name == "user_id" + + def test_predict_dataset(self, music_streaming_data): + predictor = EncoderBlock(DictOutputModel("click")) + output = predictor.predict(music_streaming_data, batch_size=10, prediction_suffix="_") + output_df = output.compute() + assert len(output_df) == 100 + + for col in music_streaming_data.schema.column_names: + assert col in output_df.columns + + assert "click_" in output_df.columns + assert "click_" in output.schema.column_names + assert len(output_df.columns) == len(output.schema) + + def test_predict_no_targets(self): + predictor = EncoderBlock(TensorOutputModel()) + + df = pd.DataFrame({"position": [1, 2, 3, 4]}) + outputs = predictor.predict(Dataset(df), batch_size=2) + output_df = outputs.compute() + + assert len(outputs.schema) == 2 + assert hasattr(output_df, "columns") + + +class TestDaskEncoder: def test_loader(self, music_streaming_data): loader = Loader(music_streaming_data, batch_size=10) ddf = music_streaming_data.to_ddf() num_items = ddf["item_id"].nunique().compute() - encoder = Encoder(TensorOutputModel()) + encoder = DaskEncoder(TensorOutputModel()) outputs = encoder(loader, index=Tags.ITEM_ID).compute() assert outputs.index.name == "item_id" assert len(outputs) == num_items def test_dataset(self, music_streaming_data): - encoder = Encoder(DictOutputModel(), selection=Tags.USER) + encoder = DaskEncoder(DictOutputModel(), selection=Tags.USER) with pytest.raises(ValueError): encoder(music_streaming_data) @@ -49,28 +103,28 @@ def test_dataset(self, music_streaming_data): assert output_df.index.name == "user_id" def test_tensor_dict(self): - encoder = Encoder(TensorOutputModel()) + encoder = DaskEncoder(TensorOutputModel()) outputs = encoder({"position": torch.tensor([1, 2, 3, 4])}) assert len(outputs) == 4 assert hasattr(outputs, "columns") def test_tensor(self): - encoder = Encoder(nn.Identity()) + encoder = DaskEncoder(nn.Identity()) outputs = encoder(torch.tensor([1, 2, 3, 4])) assert len(outputs) == 4 assert hasattr(outputs, "columns") def test_df(self): - encoder = Encoder(nn.Identity()) + encoder = DaskEncoder(nn.Identity()) outputs = encoder(pd.DataFrame({"a": [1, 2, 3, 4]})) assert len(outputs) == 4 assert hasattr(outputs, "columns") def test_exceptions(self): - encoder = Encoder(DictOutputModel()) + encoder = DaskEncoder(DictOutputModel()) with pytest.raises(ValueError): encoder("") @@ -79,9 +133,9 @@ def test_exceptions(self): encoder.encode_dataset(torch.tensor([1, 2, 3])) -class TestPredictor: +class TestDaskPredictor: def test_dataset(self, music_streaming_data): - predictor = Predictor(DictOutputModel("click"), prediction_suffix="_") + predictor = DaskPredictor(DictOutputModel("click"), prediction_suffix="_") 
output = predictor(music_streaming_data, batch_size=10) output_df = output.compute() assert len(output_df) == 100 @@ -94,7 +148,7 @@ def test_dataset(self, music_streaming_data): assert len(output_df.columns) == len(output.schema) def test_no_targets(self): - predictor = Predictor(TensorOutputModel()) + predictor = DaskPredictor(TensorOutputModel()) df = pd.DataFrame({"position": [1, 2, 3, 4]}) outputs = predictor(Dataset(df), batch_size=2) @@ -104,7 +158,7 @@ def test_no_targets(self): assert hasattr(output_df, "columns") def test_no_targets_dict(self): - predictor = Predictor(DictOutputModel()) + predictor = DaskPredictor(DictOutputModel()) df = pd.DataFrame({"user_id": [1, 2, 3, 4]}) outputs = predictor(Dataset(df), batch_size=2) diff --git a/tests/unit/torch/test_router.py b/tests/unit/torch/test_router.py index 89f9292f6c..baa981e3c9 100644 --- a/tests/unit/torch/test_router.py +++ b/tests/unit/torch/test_router.py @@ -69,7 +69,7 @@ class CustomSelect(mm.SelectKeys): def test_module_with_setup(self): class Dummy(nn.Module): - def setup_schema(self, schema: Schema): + def initialize_from_schema(self, schema: Schema): self.schema = schema def forward(self, x): @@ -162,4 +162,13 @@ def test_nested(self): outputs = module_utils.module_test(nested, self.batch.features) assert list(outputs.keys()) == ["user_age"] - assert "user_age" in mm.schema.output(nested).column_names + assert "user_age" in mm.output_schema(nested).column_names + + def test_exceptions(self): + router = mm.RouterBlock(None) + with pytest.raises(ValueError): + router.add_route(Tags.CONTINUOUS) + + router = mm.RouterBlock(self.schema, prepend_routing_module=False) + with pytest.raises(ValueError): + router.add_route(Tags.CONTINUOUS) diff --git a/tests/unit/torch/test_schema.py b/tests/unit/torch/test_schema.py index d6919c2617..f3876b0679 100644 --- a/tests/unit/torch/test_schema.py +++ b/tests/unit/torch/test_schema.py @@ -14,16 +14,22 @@ # limitations under the License. 
# +from typing import Dict + import pytest +import torch from torch import nn +from merlin.models.torch.batch import Batch +from merlin.models.torch.block import ParallelBlock from merlin.models.torch.schema import ( Selectable, - features, + feature_schema, select, select_schema, selection_name, - targets, + target_schema, + trace, ) from merlin.schema import ColumnSchema, Schema, Tags @@ -60,6 +66,10 @@ def test_select_column(self): output_2 = select(self.schema, ColumnSchema("user_id")) assert output == output_2 == Schema([column]) + def test_select_star(self): + output = select(self.schema, "*") + assert output == self.schema + def test_exceptions(self): with pytest.raises(ValueError, match="is not valid"): select(self.schema, 1) @@ -104,26 +114,30 @@ class TestSelectable: def test_exception(self): selectable = Selectable() - selectable.setup_schema(Schema([])) + selectable.initialize_from_schema(Schema([])) selectable.schema == Schema([]) with pytest.raises(NotImplementedError): selectable.select(1) class MockModule(nn.Module): - def __init__(self, feature_schema=None, target_schema=None): + def __init__(self, target_schema=None): super().__init__() - self.feature_schema = feature_schema self.target_schema = target_schema + def forward(self, inputs, batch: Batch): + return batch.features + class TestFeatures: def test_features(self): - schema = Schema([ColumnSchema("a"), ColumnSchema("b")]) - - module = MockModule(feature_schema=schema) - assert features(module) == schema - assert targets(module) == Schema() + module = MockModule() + features = {"a": torch.tensor([1]), "b": torch.tensor([2.3])} + trace(module, {}, batch=Batch(features)) + assert feature_schema(module) == Schema( + [ColumnSchema("a", dtype="int64"), ColumnSchema("b", dtype="float32")] + ) + assert target_schema(module) == Schema() class TestTargets: @@ -131,5 +145,66 @@ def test_targets(self): schema = Schema([ColumnSchema("a"), ColumnSchema("b")]) module = MockModule(target_schema=schema) - assert targets(module) == schema - assert features(module) == Schema() + assert target_schema(module) == schema + assert feature_schema(module) == Schema() + + +class TestTraceInitializeFromSchema: + """Testing initialize_from_schema works with tracing.""" + + def test_simple(self): + class Dummy(nn.Module): + def initialize_from_schema(self, schema: Schema): + self.schema = schema + + def forward(self, x): + return x + + module = Dummy() + trace(module, {"a": torch.tensor([1])}) + assert module.schema.column_names == ["a"] + + def test_parallel_tensor(self): + class Dummy(nn.Module): + def initialize_from_schema(self, schema: Schema): + self.schema = schema + + def forward(self, x: torch.Tensor): + return x + + dummy = Dummy() + identity = nn.Identity() + module = ParallelBlock({"foo": dummy, "bar": identity}) + trace(module, torch.tensor([1])) + assert dummy.schema.column_names == ["input"] + + def test_parallel_dict(self): + class Dummy(nn.Module): + def initialize_from_schema(self, schema: Schema): + self.schema = schema + + def forward(self, x: Dict[str, torch.Tensor]): + return x + + dummy = Dummy() + module = ParallelBlock({"foo": dummy}) + trace(module, {"a": torch.tensor([1])}) + assert dummy.schema.column_names == ["a"] + + def test_sequential(self): + class Dummy(nn.Module): + def initialize_from_schema(self, schema: Schema): + self.schema = schema + + def forward(self, x: Dict[str, torch.Tensor]): + output = {} + for k, v in x.items(): + output[f"{k}_output"] = v + return output + + first = Dummy() + second = Dummy() + 
module = nn.Sequential(first, second) + trace(module, {"a": torch.tensor([1])}) + assert first.schema.column_names == ["a"] + assert second.schema.column_names == ["a_output"] diff --git a/tests/unit/torch/transforms/test_agg.py b/tests/unit/torch/transforms/test_agg.py index 597226fb96..ae0177526a 100644 --- a/tests/unit/torch/transforms/test_agg.py +++ b/tests/unit/torch/transforms/test_agg.py @@ -2,7 +2,7 @@ import torch from merlin.models.torch.block import Block -from merlin.models.torch.transforms.agg import Concat, MaybeAgg, Stack +from merlin.models.torch.transforms.agg import Concat, ElementWiseSum, MaybeAgg, Stack from merlin.models.torch.utils import module_utils from merlin.schema import Schema @@ -95,6 +95,30 @@ def test_from_registry(self): assert output.shape == (2, 2, 3) +class TestElementWiseSum: + def setup_class(self): + self.ewsum = ElementWiseSum() + self.feature1 = torch.tensor([[1, 2], [3, 4]]) # Shape: [batch_size, feature_dim] + self.feature2 = torch.tensor([[5, 6], [7, 8]]) # Shape: [batch_size, feature_dim] + self.input_dict = {"feature1": self.feature1, "feature2": self.feature2} + + def test_forward(self): + output = self.ewsum(self.input_dict) + expected_output = torch.tensor([[6, 8], [10, 12]]) # Shape: [batch_size, feature_dim] + assert torch.equal(output, expected_output) + + def test_input_tensor_shape_mismatch(self): + feature_mismatch = torch.tensor([1, 2, 3]) # Different shape + input_dict_mismatch = {"feature1": self.feature1, "feature_mismatch": feature_mismatch} + with pytest.raises(RuntimeError): + self.ewsum(input_dict_mismatch) + + def test_empty_input_dict(self): + empty_dict = {} + with pytest.raises(RuntimeError): + self.ewsum(empty_dict) + + class TestMaybeAgg: def test_with_single_tensor(self): tensor = torch.tensor([1, 2, 3]) diff --git a/tests/unit/torch/transforms/test_bias.py b/tests/unit/torch/transforms/test_bias.py new file mode 100644 index 0000000000..4f29ff592e --- /dev/null +++ b/tests/unit/torch/transforms/test_bias.py @@ -0,0 +1,36 @@ +import pytest +import torch + +from merlin.models.torch.transforms.bias import LogitsTemperatureScaler +from merlin.models.torch.utils import module_utils + + +class TestLogitsTemperatureScaler: + def test_init(self): + """Test correct temperature initialization.""" + scaler = LogitsTemperatureScaler(0.5) + assert scaler.temperature == 0.5 + + def test_invalid_temperature_type(self): + """Test exception is raised for incorrect temperature type.""" + with pytest.raises(ValueError, match=r"Invalid temperature type"): + LogitsTemperatureScaler("invalid") + + def test_invalid_temperature_value(self): + """Test exception is raised for out-of-range temperature value.""" + with pytest.raises(ValueError, match=r"Invalid temperature value"): + LogitsTemperatureScaler(1.5) + + def test_temperature_scaling(self): + """Test temperature scaling of logits.""" + logits = torch.tensor([1.0, 2.0, 3.0]) + expected_scaled_logits = torch.tensor([2.0, 4.0, 6.0]) + + scaler = LogitsTemperatureScaler(0.5) + outputs = module_utils.module_test(scaler, logits) + assert torch.allclose(outputs, expected_scaled_logits) + + def test_zero_temperature_value(self): + """Test exception is raised for zero temperature value.""" + with pytest.raises(ValueError, match=r"Invalid temperature value"): + LogitsTemperatureScaler(0.0) diff --git a/tests/unit/torch/transforms/test_sequences.py b/tests/unit/torch/transforms/test_sequences.py new file mode 100644 index 0000000000..53d41015d6 --- /dev/null +++ 
b/tests/unit/torch/transforms/test_sequences.py
@@ -0,0 +1,279 @@
+from itertools import accumulate
+
+import pytest
+import torch
+
+from merlin.models.torch.batch import Batch, Sequence
+from merlin.models.torch.transforms.sequences import (
+    BroadcastToSequence,
+    TabularPadding,
+    TabularPredictNext,
+)
+from merlin.models.torch.utils import module_utils
+from merlin.schema import ColumnSchema, Schema, Tags
+
+
+def _get_values_offsets(data):
+    values = []
+    row_lengths = []
+    for row in data:
+        row_lengths.append(len(row))
+        values += row
+    offsets = [0] + list(accumulate(row_lengths))
+    return torch.tensor(values), torch.tensor(offsets)
+
+
+class TestTabularPadding:
+    @pytest.fixture
+    def sequence_batch(self):
+        a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]])
+        b_values, b_offsets = _get_values_offsets([[34, 30], [], [33, 23, 50]])
+        features = {
+            "a__values": a_values,
+            "a__offsets": a_offsets,
+            "b__values": b_values,
+            "b__offsets": b_offsets,
+            "c_dense": torch.Tensor([[1, 2, 0], [0, 0, 0], [4, 5, 6]]),
+            "d_context": torch.Tensor([1, 2, 3]),
+        }
+        targets = None
+        return Batch(features, targets)
+
+    @pytest.fixture
+    def sequence_schema(self):
+        return Schema(
+            [
+                ColumnSchema("a", tags=[Tags.SEQUENCE]),
+                ColumnSchema("b", tags=[Tags.SEQUENCE]),
+                ColumnSchema("c_dense", tags=[Tags.SEQUENCE]),
+                ColumnSchema("d_context", tags=[Tags.CONTEXT]),
+            ]
+        )
+
+    def test_padded_features(self, sequence_batch, sequence_schema):
+        _max_sequence_length = 8
+        padding_op = TabularPadding(
+            schema=sequence_schema, max_sequence_length=_max_sequence_length
+        )
+        padded_batch = module_utils.module_test(padding_op, sequence_batch)
+
+        assert torch.equal(padded_batch.sequences.length("a"), torch.Tensor([2, 0, 3]))
+        assert set(padded_batch.features.keys()) == set(["a", "b", "c_dense", "d_context"])
+        for feature in ["a", "b", "c_dense"]:
+            assert padded_batch.features[feature].shape[1] == _max_sequence_length
+
+    def test_batch_invalid_lengths(self):
+        # Test that sequential inputs with mismatched row lengths raise a ValueError
+        a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]])
+        b_values, b_offsets = _get_values_offsets([[34], [23, 56], [33, 23, 50, 4]])
+
+        with pytest.raises(
+            ValueError,
+            match="The sequential inputs must have the same length for each row in the batch",
+        ):
+            padding_op = TabularPadding(schema=Schema(["a", "b"]), selection=None)
+            padding_op(
+                inputs=None,
+                batch=Batch(
+                    {
+                        "a__values": a_values,
+                        "a__offsets": a_offsets,
+                        "b__values": b_values,
+                        "b__offsets": b_offsets,
+                    }
+                ),
+            )
+
+    def test_padded_targets(self, sequence_batch, sequence_schema):
+        _max_sequence_length = 8
+        target_values, target_offsets = _get_values_offsets([[10, 11], [], [12, 13, 14]])
+        sequence_batch.targets = {
+            "target_1": torch.Tensor([3, 4, 6]),
+            "target_2__values": target_values,
+            "target_2__offsets": target_offsets,
+        }
+        padding_op = TabularPadding(
+            schema=sequence_schema, max_sequence_length=_max_sequence_length
+        )
+        padded_batch = module_utils.module_test(padding_op, sequence_batch)
+
+        assert padded_batch.targets["target_2"].shape[1] == _max_sequence_length
+        assert torch.equal(padded_batch.targets["target_1"], sequence_batch.targets["target_1"])
+
+
+class TestBroadcastToSequence:
+    def setup_method(self):
+        self.input_tensors = {
+            "feature_1": torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
+            "feature_2": torch.tensor(
+                [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]]
+            ),
+            "feature_3": 
torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]), + } + self.schema = Schema(list(self.input_tensors.keys())) + self.to_broadcast = Schema(["feature_1", "feature_3"]) + self.sequence = Schema(["feature_2"]) + self.broadcast = BroadcastToSequence(self.to_broadcast, self.sequence) + + def test_initialize_from_schema(self): + self.broadcast.initialize_from_schema(self.schema) + assert self.broadcast.to_broadcast_features == ["feature_1", "feature_3"] + assert self.broadcast.sequence_features == ["feature_2"] + + def test_get_seq_length(self): + self.broadcast.initialize_from_schema(self.schema) + assert self.broadcast.get_seq_length(self.input_tensors) == 3 + + def test_get_seq_length_offsets(self): + self.broadcast.initialize_from_schema(self.schema) + + inputs = { + "feature_1": torch.tensor([1, 2]), + "feature_2__offsets": torch.tensor([2, 3]), + "feature_3": torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]), + } + + assert self.broadcast.get_seq_length(inputs) == 3 + + def test_forward(self): + self.broadcast.initialize_from_schema(self.schema) + output = module_utils.module_test(self.broadcast, self.input_tensors) + assert output["feature_1"].shape == (2, 3, 3) + assert output["feature_3"].shape == (2, 3, 3) + assert output["feature_2"].shape == (2, 3, 2) + + def test_unsupported_dimensions(self): + self.broadcast.initialize_from_schema(self.schema) + self.input_tensors["feature_3"] = torch.rand(10, 3, 3, 3) + + with pytest.raises(RuntimeError, match="Unsupported number of dimensions: 4"): + self.broadcast(self.input_tensors) + + +class TestTabularPredictNext: + @pytest.fixture + def sequence_batch(self): + a_values, a_offsets = _get_values_offsets(data=[[1, 2, 3], [3, 6], [3, 4, 5, 6]]) + b_values, b_offsets = _get_values_offsets([[34, 30, 31], [30, 31], [33, 23, 50, 51]]) + + c_values, c_offsets = _get_values_offsets([[1, 2, 3, 4], [5, 6], [5, 6, 7, 8, 9, 10]]) + d_values, d_offsets = _get_values_offsets( + [[10, 20, 30, 40], [50, 60], [50, 60, 70, 80, 90, 100]] + ) + + features = { + "a__values": a_values, + "a__offsets": a_offsets, + "b__values": b_values, + "b__offsets": b_offsets, + "c__values": c_values, + "c__offsets": c_offsets, + "d__values": d_values, + "d__offsets": d_offsets, + "e_dense": torch.Tensor([[1, 2, 3, 0], [5, 6, 0, 0], [4, 5, 6, 7]]), + "f_context": torch.Tensor([1, 2, 3, 4]), + } + targets = None + return Batch(features, targets) + + @pytest.fixture + def sequence_schema_1(self): + return Schema( + [ + ColumnSchema("a", tags=[Tags.SEQUENCE]), + ColumnSchema("b", tags=[Tags.SEQUENCE]), + ColumnSchema("e_dense", tags=[Tags.SEQUENCE]), + ] + ) + + @pytest.fixture + def sequence_schema_2(self): + return Schema( + [ + ColumnSchema("c", tags=[Tags.SEQUENCE, Tags.ID]), + ColumnSchema("d", tags=[Tags.SEQUENCE]), + ] + ) + + @pytest.fixture + def padded_batch(self, sequence_schema_1, sequence_batch): + padding_op = TabularPadding(schema=sequence_schema_1) + return padding_op(sequence_batch) + + def test_tabular_sequence_transform_wrong_inputs(self, padded_batch, sequence_schema_1): + with pytest.raises( + ValueError, + match="The target 'Tags.ID' was not found in the provided sequential schema:", + ): + transform = TabularPredictNext( + schema=sequence_schema_1, + target=Tags.ID, + ) + + transform = TabularPredictNext( + schema=sequence_schema_1, + target="a", + apply_padding=False, + ) + with pytest.raises( + ValueError, + match="The input `batch` should include information about input sequences lengths", + ): + transform(Batch({"b": padded_batch.features["b"]})) + 
+ with pytest.raises( + ValueError, + match="Inputs features do not contain target column", + ): + transform(Batch({"b": padded_batch.features["b"]}, sequences=padded_batch.sequences)) + + with pytest.raises( + ValueError, match="must be greater than 1 for sequential input to be shifted as target" + ): + transform = TabularPredictNext( + schema=sequence_schema_1.select_by_name("a"), target="a", apply_padding=False + ) + transform( + Batch( + {"a": torch.Tensor([[1, 2], [1, 0], [3, 4]])}, + sequences=Sequence(lengths={"a": torch.Tensor([2, 1, 2])}), + ) + ) + + def test_transform_predict_next(self, sequence_batch, padded_batch, sequence_schema_1): + transform = TabularPredictNext(schema=sequence_schema_1, target="a") + + batch_output = module_utils.module_test(transform, sequence_batch) + + assert list(batch_output.features.keys()) == ["a", "b", "e_dense"] + for k in ["a", "b", "e_dense"]: + assert torch.equal(batch_output.features[k], padded_batch.features[k][:, :-1]) + assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([2, 1, 3])) + + def test_transform_predict_next_multi_sequence( + self, sequence_batch, padded_batch, sequence_schema_1, sequence_schema_2 + ): + import merlin.models.torch as mm + + transform_1 = TabularPredictNext(schema=sequence_schema_1, target="a") + transform_2 = TabularPredictNext(schema=sequence_schema_2) + transform_block = mm.BatchBlock( + mm.ParallelBlock({"transform_1": transform_1, "transform_2": transform_2}) + ) + batch_output = transform_block(sequence_batch) + + assert list(batch_output.features.keys()) == ["a", "b", "e_dense", "f_context", "c", "d"] + assert list(batch_output.targets.keys()) == ["a", "c"] + + assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([2, 1, 3])) + assert torch.equal(batch_output.sequences.length("c"), torch.Tensor([3, 1, 5])) + assert torch.all( + batch_output.sequences.mask("a") # target mask + == torch.Tensor( + [ + [True, True, False], + [True, False, False], + [True, True, True], + ] + ) + ) diff --git a/tests/unit/torch/transforms/test_tuple.py b/tests/unit/torch/transforms/test_tuple.py new file mode 100644 index 0000000000..05dd4a6f2c --- /dev/null +++ b/tests/unit/torch/transforms/test_tuple.py @@ -0,0 +1,28 @@ +import pytest +import torch + +from merlin.models.torch.transforms import tuple +from merlin.models.torch.utils import module_utils +from merlin.schema import Schema + + +class TestToTuple: + @pytest.mark.parametrize("length", [i + 1 for i in range(10)]) + def test_with_length(self, length): + schema = Schema([str(i) for i in range(length)]) + to_tuple = tuple.ToTuple(schema) + assert isinstance(to_tuple, getattr(tuple, f"ToTuple{length}")) + + inputs = {str(i): torch.randn(2, 3) for i in range(length)} + outputs = module_utils.module_test(to_tuple, inputs) + + assert len(outputs) == length + + def test_exception(self): + with pytest.raises(ValueError): + tuple.ToTuple(Schema([str(i) for i in range(11)])) + + to_tuple = tuple.ToTuple2() + inputs = {"0": torch.randn(2, 3), "1": torch.randn(2, 3)} + with pytest.raises(RuntimeError): + module_utils.module_test(to_tuple, inputs) diff --git a/tests/unit/torch/utils/test_traversal_utils.py b/tests/unit/torch/utils/test_traversal_utils.py index d44b53e55d..33bdd10822 100644 --- a/tests/unit/torch/utils/test_traversal_utils.py +++ b/tests/unit/torch/utils/test_traversal_utils.py @@ -1,4 +1,3 @@ -import pytest from torch import nn import merlin.models.torch as mm @@ -86,10 +85,9 @@ def __init__(self): model = CustomModule() assert 
isinstance(leaf(model), CustomModule)
 
-    def test_exception(self):
-        model = nn.Sequential(nn.Linear(10, 20), nn.Linear(10, 20))
-        with pytest.raises(ValueError):
-            leaf(model)
+    def test_sequential(self):
+        model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30))
+        assert leaf(model).out_features == 30
 
     def test_embedding(self, user_id_col_schema):
         input_block = mm.TabularInputBlock(Schema([user_id_col_schema]), init="defaults")
diff --git a/tox.ini b/tox.ini
index 770abbaa00..2e4a07fd40 100644
--- a/tox.ini
+++ b/tox.ini
@@ -54,8 +54,7 @@ commands =
     bash -c 'cp $(python -c "import sys; print(sys.base_prefix)")/lib/*.so* $(python -c "import sys; print(sys.prefix)")/lib'
     bash -c 'python -m pytest --cov-report term --cov merlin -m "{env:PYTEST_MARKERS}" -rxs {posargs:tests} || ([ $? = 5 ] && exit 0 || exit $?)'
 
-
-[testenv:multi-gpu]
+[testenv:horovod-gpu]
 ; Runs in: Github Actions
 ; Runs GPU-based tests.
 allowlist_externals =
@@ -76,7 +75,7 @@ deps =
     git+https://github.com/NVIDIA-Merlin/NVTabular.git@{env:MERLIN_BRANCH}
 commands =
     sh examples/usecases/multi-gpu/install_sparse_operation_kit.sh {envdir}
-    bash -c 'horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh python -m pytest -m "unit and horovod {env:EXTRA_PYTEST_MARKERS}" -rxs {posargs:tests} || ([ $? = 5 ] && exit 0 || exit $?)'
+    bash -c 'horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh python -m pytest -m "unit and horovod {env:PYTEST_MARKERS}" -rxs {posargs:tests} || ([ $? = 5 ] && exit 0 || exit $?)'
 
 [testenv:horovod-cpu]
 setenv =
@@ -93,7 +92,7 @@ commands =
     conda env create --prefix {envdir}/env --file requirements/horovod-cpu-environment.yml --force
    {envdir}/env/bin/python -m pip install 'horovod==0.27.0' --no-cache-dir
    {envdir}/env/bin/horovodrun --check-build
-    {envdir}/env/bin/horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh pytest -m "unit and horovod {env:EXTRA_PYTEST_MARKERS}" -rxs {posargs:tests}
+    {envdir}/env/bin/horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh pytest -m "unit and horovod {env:PYTEST_MARKERS}" -rxs {posargs:tests}
 
 [testenv:nvtabular-cpu]
 passenv=GIT_COMMIT