Skip to content

Commit

Permalink
Merge pull request #517 from llaniewski/feature/autosym
Browse files Browse the repository at this point in the history
Introducing `autosym2`
  • Loading branch information
llaniewski authored Jun 3, 2024
2 parents ac87c5f + f4c66ff commit 68e4df6
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 21 deletions.
42 changes: 33 additions & 9 deletions src/LatticeAccess.inc.cpp.Rt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
#define real_to_storage_shift(x__,y__) real_to_storage(x__)
#endif

#ifndef NODE_SYMZ
#define NODE_SYMZ 0
#endif

/// Get only type of node
CudaDeviceFunction flag_t LatticeContainer::getType(int x, int y, int z) const
{
Expand Down Expand Up @@ -417,21 +421,23 @@ public:

si = which(Fields$name == fn)
sf = rows(Fields)[[si]]
sd = d
sd[i] = sd[i]*(-1)
sd_plus = d
sd_plus[i] = autosym_shift-sd_plus[i]
sd_minus = d
sd_minus[i] = -autosym_shift-sd_minus[i]
?>
template < class PARENT >
template <class dx_t, class dy_t, class dz_t>
CudaDeviceFunction real_t SymmetryAccess< PARENT >::<?%s paste0(this_fun, f$nicename) ?> (const dx_t & dx, const dy_t & dy, const dz_t & dz) const
{
<?R if (paste0("SYM",ch[i]) %in% NodeTypes$group) { ?>
if (<?R C(d[i]) ?> > range_int<0>()) {
if ((this->getNodeType() & NODE_SYM<?%s ch[i] ?>) == NODE_Symmetry<?%s ch[i] ?>_plus) {
return <?%s paste0(sig, next_fun, sf$nicename) ?>(<?R C(sd,sep=", ") ?>);
if ((this->getNodeType() & NODE_SYM<?%s ch[i] ?>) == NODE_<?%s autosym_name ?><?%s ch[i] ?>_plus) {
return <?%s paste0(sig, next_fun, sf$nicename) ?>(<?R C(sd_plus,sep=", ",float=FALSE,wrap.const=range_int) ?>);
}
} else if (<?R C(d[i]) ?> < range_int<0>()) {
if ((this->getNodeType() & NODE_SYM<?%s ch[i] ?>) == NODE_Symmetry<?%s ch[i] ?>_minus) {
return <?%s paste0(sig, next_fun, sf$nicename) ?>(<?R C(sd,sep=", ") ?>);
if ((this->getNodeType() & NODE_SYM<?%s ch[i] ?>) == NODE_<?%s autosym_name ?><?%s ch[i] ?>_minus) {
return <?%s paste0(sig, next_fun, sf$nicename) ?>(<?R C(sd_minus,sep=", ",float=FALSE,wrap.const=range_int) ?>);
}
}
<?R } ?>
Expand All @@ -444,9 +450,27 @@ CudaDeviceFunction real_t SymmetryAccess< PARENT >::<?%s paste0(this_fun, f$nice
template < class PARENT >
template <class N>
CudaDeviceFunction void SymmetryAccess< PARENT >::pop<?%s s$suffix ?>(N & node) const
{
parent::pop<?%s s$suffix ?>(node);
<?R resolve.symmetries(Density[s$load.densities,,drop=FALSE]) ?>
{ <?R
if (Options$autosym == 0) { ?>
parent::pop<?%s s$suffix ?>(node); <?R
} else if (Options$autosym == 1) { ?>
parent::pop<?%s s$suffix ?>(node); <?R
resolve.symmetries(Density[s$load.densities,,drop=FALSE])
} else if (Options$autosym == 2) { ?>
if (this->getNodeType() & (NODE_SYMX | NODE_SYMY | NODE_SYMZ)) { <?R
dens = Density;
dens$load = s$load.densities;
for (d in rows(dens)) if (d$load) {
f = rows(Fields)[[match(d$field, Fields$name)]]
dp = c(-d$dx, -d$dy, -d$dz) ?>
<?%s paste("node",d$name,sep=".") ?> = load_<?%s f$nicename ?>(range_int< <?%d dp[1] ?> >(),range_int< <?%d dp[2] ?> >(),range_int< <?%d dp[3] ?> >()); <?R
} else if (!is.na(d$default)) { ?>
<?%s paste("node",d$name,sep=".") ?> = <?%f d$default ?>; <?R
} ?>
} else {
parent::pop<?%s s$suffix ?>(node);
} <?R
} else stop("Unknown autosym option") ?>
}
<?R } } ?>

Expand Down
61 changes: 53 additions & 8 deletions src/conf.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ if (! SYMALGEBRA) {
library(symAlgebra,quietly=TRUE,warn.conflicts=FALSE)
}

if (is.null(Options$autosym)) Options$autosym = FALSE
if (is.null(Options$autosym)) Options$autosym = 0

#source("linemark.R")

Expand Down Expand Up @@ -370,7 +370,7 @@ if (is.null(Description)) {
}


if (Options$autosym) { ## Automatic symmetries
if (Options$autosym > 0) { ## Automatic symmetries
symmetries = data.frame(symX=c(-1,1,1),symY=c(1,-1,1),symZ=c(1,1,-1))

for (g in unique(DensityAll$group)) {
Expand Down Expand Up @@ -399,15 +399,60 @@ if (Options$autosym) { ## Automatic symmetries
Fields[sel,s] = Fields$name[sel]
}

AddNodeType("SymmetryX_plus", group="SYMX")
AddNodeType("SymmetryX_minus", group="SYMX")
AddNodeType("SymmetryY_plus", group="SYMY")
AddNodeType("SymmetryY_minus", group="SYMY")
rownames(Fields) = Fields$name

if (Options$autosym == 1) {
autosym_shift = 0
autosym_name = "Symmetry"
} else if (Options$autosym == 2) {
autosym_shift = 1
autosym_name = "SymmetryEdge"
} else stop("unknown autosym value")

directions = lapply(rows(Fields), function(f) expand.grid(dx=f$minx:f$maxx,dy=f$miny:f$maxy,dz=f$minz:f$maxz))
names(directions) = Fields$name
dir.sort = function(d) {
d = unique(d)
d[order(d[,3],d[,2],d[,1]),]
}
directions = lapply(directions,dir.sort)
tmp = NULL
while (!identical(directions, tmp)) {
tmp = directions
for (f in rows(Fields)) {
d = directions[[f$name]]
for (i in 1:3) {
nfn = f[[names(symmetries)[i]]]
od = directions[[nfn]]
cr = od
nd = d[d[,i] < 0,, drop=FALSE]
nd[,i] = -nd[,i] - autosym_shift
cr = rbind(cr,nd)
nd = d[d[,i] > 0,, drop=FALSE]
nd[,i] = -nd[,i] + autosym_shift
cr = rbind(cr,nd)
directions[[nfn]] = cr
}
}
directions = lapply(directions,dir.sort)
}
Fields$minx = sapply(directions, function(x) min(x$dx))
Fields$maxx = sapply(directions, function(x) max(x$dx))
Fields$miny = sapply(directions, function(x) min(x$dy))
Fields$maxy = sapply(directions, function(x) max(x$dy))
Fields$minz = sapply(directions, function(x) min(x$dz))
Fields$maxz = sapply(directions, function(x) max(x$dz))


AddNodeType(paste0(autosym_name, "X_plus"), group="SYMX")
AddNodeType(paste0(autosym_name, "X_minus"), group="SYMX")
AddNodeType(paste0(autosym_name, "Y_plus"), group="SYMY")
AddNodeType(paste0(autosym_name, "Y_minus"), group="SYMY")
if (all(range(Fields$minz,Fields$maxz) == c(0,0))) {
# we're in 2D
} else {
AddNodeType("SymmetryZ_plus", group="SYMZ")
AddNodeType("SymmetryZ_minus", group="SYMZ")
AddNodeType(paste0(autosym_name, "Z_plus"), group="SYMZ")
AddNodeType(paste0(autosym_name, "Z_minus"), group="SYMZ")
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/makefile.main.Rt
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ for (d in destinations) {
paste0(" ",
names(opts),
" = ",
ifelse(opts==0,"FALSE","TRUE"),
opts,
ifelse(seq_along(opts) != length(opts),",","")
)
)
Expand Down
18 changes: 15 additions & 3 deletions src/models.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,23 @@ get.models = function() {
opts_terms = terms(opts)
opts = attr(opts_terms,"factors")
opts = data.frame(t(opts))
rownames(opts) = paste(name,gsub(":","_",rownames(opts)),sep="_")
opts[] = opts > 0
if (attr(opts_terms, "intercept") == 1) {
opts[name,]=0
opts = rbind(opts, FALSE)
}
# opts = apply(opts, 2, function(x) x > 0)
if ("autosym" %in% names(opts)) {
opts$autosym = ifelse(opts$autosym, 1, 0)
x = opts[opts$autosym > 0,,drop=FALSE]
x$autosym = 2
opts = rbind(opts, x)
}
rownames(opts) = sapply(seq_len(nrow(opts)), function(i) {
x = opts[i,]
w = names(opts)
w = paste0(w, ifelse(x>1,x,""))
w = c(name, w[x>0])
paste(w,collapse="_")
})
} else {
opts = data.frame(row.names=name)
}
Expand Down

0 comments on commit 68e4df6

Please sign in to comment.