Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
sheepforce committed Dec 16, 2024
1 parent 063c427 commit 9a307e4
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 241 deletions.
231 changes: 0 additions & 231 deletions data/trexio.json

This file was deleted.

6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions get-json-spec.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#! /usr/bin/env bash

# Create a temporary C file, that merely imports trexio.h
TREXIO_TMP=$(mktemp --suffix=.c)
trap "rm -f $TREXIO_TMP" EXIT ERR

# Run the C preprocessor on that file to get the path to trexio.h
echo "#include <trexio.h>" > $TREXIO_TMP
TREXIO_HEADER_PATH=$(gcc -E test.c | grep '^# [0-9]\+ "' | grep 'trexio\.h' | head -n 1 | awk -F'"' '{print $2}')

# Extract the JSON specification from a comment in the header file
sed -n '/\/\* JSON configuration/,/\*\//p' $TREXIO_HEADER_PATH | sed '1d;$d'

4 changes: 3 additions & 1 deletion nix/overlay.nix
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ final: prev: {
in
prev.haskell.packageOverrides hfinal hprev
// {
trexio-hs = hfinal.callCabal2nix "trexio-hs" ../. { inherit (final) trexio; };
trexio-hs = hfinal.callCabal2nix "trexio-hs" ../. {
inherit (final) trexio;
};
};
};
}
44 changes: 42 additions & 2 deletions src/TREXIO/Internal/TH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,22 @@ import Data.Char
import Data.Coerce
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Massiv.Array as Massiv hiding (Dim, forM, forM_, mapM, product, replicate, toList, zip, throwM)
import Data.Massiv.Array as Massiv hiding (Dim, forM, forM_, mapM, product, replicate, throwM, toList, zip)
import Data.Massiv.Array qualified as Massiv
import Data.Massiv.Array.Manifest.Vector qualified as Massiv
import Data.Massiv.Array.Unsafe (unsafeWithPtr)
import Data.Maybe
import Data.Text (Text)
import Data.Text qualified as T
import Data.Vector qualified as V
import Foreign hiding (peekArray, withArray)
import Foreign hiding (peekArray, void, withArray)
import Foreign.C.ConstPtr
import Foreign.C.String
import Foreign.C.Types
import GHC.Generics (Generic)
import Language.Haskell.TH
import Language.Haskell.TH.Syntax (Lift (..))
import System.Process.Typed
import TREXIO.CooArray
import TREXIO.Internal.Base
import TREXIO.Internal.Marshaller
Expand All @@ -41,6 +42,21 @@ tshow = T.pack . show

--------------------------------------------------------------------------------

{- | Attempts to obtain the JSON specification from the trexio.h header. Magic
happens in a bash script
-}
getJsonSpec :: (MonadIO m, MonadThrow m) => m TrexioScheme
getJsonSpec = do
(ec, stdout, stderr) <- readProcess . shell $ "./get-json-spec.sh"
jsonSpec <- case ec of
ExitSuccess -> return stdout
ExitFailure _ -> throwString . show $ stderr
case eitherDecode jsonSpec of
Left err -> throwString err
Right spec -> return spec

--------------------------------------------------------------------------------

{- | The overall data structure TREXIO uses to represent a wave function as a
JSON specification. A TREXIO scheme consists of multiple data groups and each
data group has multiple fields. A field may require knowledge of other fields.
Expand Down Expand Up @@ -659,6 +675,20 @@ mkReadFns groupName dataName fieldType = case dims of
|]
| otherwise -> error $ "mkReadFns: unsupported field type for 3D data: " <> show fieldType
[d1, d2, d3, d4]
| isFloatField fieldType ->
[e|
\trexio ->
liftIO $ do
sz1 <- $(mkSizeFn d1) trexio
sz2 <- $(mkSizeFn d2) trexio
sz3 <- $(mkSizeFn d3) trexio
sz4 <- $(mkSizeFn d4) trexio
allocaArray (sz1 * sz2 * sz3 * sz4) $ \buf -> do
ec <- exitCodeH <$> $(varE . mkName $ mkCFnName Read groupName dataName) trexio buf
case ec of
Success -> peekArray (Sz4 sz1 sz2 sz3 sz4) (castPtr buf)
_ -> throwM ec
|]
| isSparseFloat fieldType ->
[e|
\trexio -> liftIO $ do
Expand Down Expand Up @@ -982,6 +1012,16 @@ mkWriteFns scheme groupName dataName fieldType = case dims of
|]
| otherwise -> error $ "mkWriteFns: unsupported field type for 3D data: " <> show fieldType
[d1, d2, d3, d4]
| isFloatField fieldType ->
[e|
\trexio arr -> liftIO . unsafeWithPtr arr $ \arrPtr -> do
let Sz4 sz1 sz2 sz3 sz4 = size arr
$(mkWriteSzFn scheme d1) trexio sz1
$(mkWriteSzFn scheme d2) trexio sz2
$(mkWriteSzFn scheme d3) trexio sz3
$(mkWriteSzFn scheme d4) trexio sz4
checkEC $ $(varE . mkName $ mkCFnName Write groupName dataName) trexio (castPtr arrPtr)
|]
| isSparseFloat fieldType ->
[e|
\trexio cooArr -> liftIO $ do
Expand Down
2 changes: 1 addition & 1 deletion src/TREXIO/LowLevel/Scheme.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ import Language.Haskell.TH.Syntax (lift)

scheme :: TrexioScheme
scheme = $(do
Just trexio <- runIO $ decodeFileStrict @TrexioScheme "./data/trexio.json"
trexio <- runIO getJsonSpec
lift trexio
)
7 changes: 4 additions & 3 deletions trexio-hs.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ maintainer: [email protected]
category: Data
build-type: Simple
extra-doc-files: CHANGELOG.md
data-files: data/trexio.json
data-files: get-json-spec.sh
tested-with: GHC == {9.6.6, 9.8.2}
description:
This package provides low- and high-level Haskell bindings for [TREXIO, a portable file format for storing wave function data](https://trex-coe.github.io/trexio/).
Expand Down Expand Up @@ -81,6 +81,7 @@ common deps
massiv >= 1.0.0.0 && < 1.1,
safe-exceptions >= 0.1.7 && < 0.2,
template-haskell >= 2.20 && < 2.23,
typed-process >= 0.2.12 && < 0.3,
text >= 2.0 && < 2.2,
vector >= 0.13 && < 0.14

Expand Down Expand Up @@ -147,6 +148,6 @@ test-suite trexio-test
tasty >= 1.4 && < 1.6,
tasty-hunit >= 0.10 && < 0.11,
tasty-hedgehog >= 1.4 && < 1.5,
hedgehog >= 1.4 && < 1.6,
temporary >= 1.3 && < 1.4
temporary >= 1.3 && < 1.4,
hedgehog >= 1.4 && < 1.6

0 comments on commit 9a307e4

Please sign in to comment.