For most users, the vanilla function `calc_weight` should be sufficient for use cases with binary treatments. However, for users with more complicated data structures or problems, the `causalOT` package offers a more flexible interface that relies heavily on the `torch` package. We will walk through a few use cases here to show how one might use the object-oriented programming (OOP) objects.

One **very** important thing to note is that these objects are mutable; in other words, they are always passed by reference, so changes to a base object will affect every other object that relies on it. Changes therefore propagate both forward and backward. For these reasons, these objects are more prone to side effects and should be used carefully.
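Since `Measure` is an `R6` class, and `R6` objects are R environments, the reference semantics described above can be illustrated with a plain environment. The minimal sketch below uses only base R and does not touch `causalOT`; the names `m_ref` and `alias` are hypothetical stand-ins for a mutable object.

```r
# A plain environment stands in for a mutable R6 object (hypothetical example)
m_ref <- new.env()
m_ref$weights <- c(0.5, 0.5)

alias <- m_ref                  # a second name for the SAME environment (no copy is made)
alias$weights <- c(0.9, 0.1)    # modify through the alias ...
m_ref$weights                   # ... and the original now also holds c(0.9, 0.1)

# Ordinary R vectors are copied on modification instead
v <- c(0.5, 0.5)
v_copy <- v
v_copy[1] <- 0.9
v                               # still c(0.5, 0.5)
```

If you need an independent copy of an `R6` object, `R6` provides a `clone()` method; we have not checked how `causalOT`'s classes handle cloning of the `torch` tensors they hold, so treat that as an assumption to verify.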

Finally, these data structures rely heavily on the `torch` package in `R`. This allows relatively easy use of GPUs and has other advantages, such as passing by reference and access to various optimization methods by default.

The fundamental object underlying the OOP methods in the package is the `R6` class `Measure`. These objects are named for the fact that they specify an empirical distribution on a set of support points. In light of this, the first two arguments should be intuitive: `x`, the set of data for the measure, and `weights`, the empirical mass.

```
n <- 5
d <- 3
x <- matrix(stats::rnorm(n*d), nrow = n)
w <- stats::runif(n)
w <- w/sum(w)
m <- Measure(x = x, weights = w)
```

We can also view the weights and data of a measure object by accessing these public fields:

```
m$x
#> torch_tensor
#> -0.1088 0.1408 -0.9207
#> 0.0217 -1.5297 -0.2621
#> 1.0911 0.1324 -0.2298
#> -0.0686 0.5114 -0.0258
#> 0.7843 1.1124 -1.3936
#> [ CPUDoubleType{5,3} ]
m$weights
#> torch_tensor
#> 0.2228
#> 0.1881
#> 0.2386
#> 0.2295
#> 0.1211
#> [ CPUDoubleType{5} ]
```

The next argument in the constructor function, `probability.measure`, tells the function whether your weights are a probability measure (i.e., the weights are nonnegative and sum to 1) or a more general type of measure. The default assumption is that you are using a probability measure.
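As a quick illustration of the distinction, the hypothetical helper `is_probability_measure()` below (not part of `causalOT`) checks the two conditions in base R. It is only a sketch of the definition, not how the package validates its input.

```r
# Hypothetical helper (not part of causalOT): is w a probability measure,
# i.e. are all masses nonnegative and do they sum to 1 (up to a tolerance)?
is_probability_measure <- function(w, tol = 1e-8) {
  all(w >= 0) && abs(sum(w) - 1) < tol
}

set.seed(1)
w_raw <- stats::runif(5)
is_probability_measure(w_raw)               # FALSE: positive, but the sum is not 1
is_probability_measure(w_raw / sum(w_raw))  # TRUE after normalizing
is_probability_measure(c(1.5, -0.5))        # FALSE: sums to 1 but has a negative mass
```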

Then we come to a very important argument: `adapt`. This tells the function whether you want to change nothing and keep the measure static (`"none"`), adapt the weights towards another measure (`"weights"`), or move the data points of the measure itself (`"x"`).

**Adapting the measure to functions of target data.** The next two arguments, `balance.functions` and `target.values`, are useful when you want to adapt specific functions of the `Measure` to target data. Typically, these target functions will be the empirical means of some aspect of the covariates in a target data set. As an example:

```
target.data <- matrix(rnorm(n*d),n,d)
target.values <- colMeans(target.data)
m <- Measure(x = x, weights = w,
             probability.measure = TRUE,
             adapt = "weights",
             target.values = target.values)
```

Note that if we don’t supply the `balance.functions` argument but do provide `target.values`, the function will use the data in `x` as the `balance.functions`. We can view our balance functions and target values through the `balance_functions` and `balance_target` fields.

Note that the values returned differ from the originals. This is because the software divides the balance functions and target values by the standard deviation of each balance function.

```
all.equal(as.numeric(m$balance_target), target.values)
#> [1] "Mean relative difference: 0.396385"
all.equal(as.matrix(m$balance_functions), x)
#> [1] "Mean relative difference: 0.3160986"
sds <- apply(x,2,sd)
all.equal(as.numeric(m$balance_target), target.values/sds)
#> [1] TRUE
all.equal(as.matrix(m$balance_functions), sweep(x,2,sds,"/"))
#> [1] TRUE
```
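The standardization above also determines how balance is measured. The sketch below, in base R, computes the standardized imbalance \(|\mathbb{E}_w(B(x)) - \text{target}| / \sigma\) for each balance function and checks it against a tolerance `delta`. The data, the uniform weights, and `delta = 0.1` are made up for illustration, and we are assuming that `causalOT` standardizes with the same column standard deviations used in the check above.

```r
set.seed(2)
x_demo <- matrix(stats::rnorm(5 * 3), nrow = 5)                   # balance functions B(x)
target_demo <- colMeans(matrix(stats::rnorm(5 * 3), nrow = 5))    # hypothetical target means
w_demo <- rep(1 / 5, 5)     # uniform weights as a starting point
delta <- 0.1                # hypothetical tolerance

sds <- apply(x_demo, 2, sd)             # sd of each balance function (column)
B_std <- sweep(x_demo, 2, sds, "/")     # standardized balance functions
target_std <- target_demo / sds         # standardized target values

# |E_w[B(x)] - target| / sd, one entry per balance function
imbalance <- abs(colSums(w_demo * B_std) - target_std)
imbalance
all(imbalance <= delta)     # TRUE only if w_demo satisfies every balance constraint
```

Standardizing first and measuring the imbalance on the raw scale divided by the standard deviation give the same numbers, which is why a single `delta` can be applied to balance functions measured in different units.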

Of course, if `adapt = "none"`, then the `balance.functions` and `target.values` serve essentially no purpose.

Finally, the arguments `dtype` and `device` set the data type and device of the `torch_tensor`s in the underlying data structures. For more information, see the `torch` documentation.

We can also print a measure object to the screen to see some of the underlying information quickly. The printout includes the object address, which can be useful for distinguishing between objects.

```
m
#> Measure: 0x7fdd8b7c4eb8
#> x : a 5x3 matrix
#> -0.11, 0.14, -0.92
#> 0.02, -1.53, -0.26
#> 1.09, 0.13, -0.23
#> -0.07, 0.51, -0.03
#> 0.78, 1.11, -1.39
#> weights: 0.22, 0.19, 0.24, 0.23, 0.12
#> balance:
#> funct.: -0.2, 0.14, -1.61
#> target: 0.55, -0.2, -1.53 …
#> adapt : weights
#> dtype : torch_Double
#> device : torch_device(type='cpu')
```

The next important component of the OOP framework in `causalOT` is the `OTProblem` object. Say we have two measures: one we want to target, `m_target`, and one we want to adapt to the target measure by changing its weights, `m_source`.

```
m_target <- Measure(x = matrix(rnorm(n*2*d), n*2,d))
m_source <- Measure(x = x, weights = w, adapt = "weights")
```

Now we need some way of adapting `m_source`, and in this package we use optimal transport methods. Thus, we specify our optimal transport problem, `otp <- OTProblem(m_source, m_target)`.

The `OTProblem` is the basis for setting up the following objective function:
\[
\begin{aligned}
w^\star &= \operatorname{argmin}_w \, OT_\lambda\big(m_{\text{source}}(w), m_{\text{target}}\big) \\
&\text{s.t. } \left| \frac{\mathbb{E}_w\big(B(x_{\text{source}})\big) - \mathbb{E}\big(B(x_{\text{target}})\big)}{\sigma} \right| \leq \delta.
\end{aligned}
\]
Here \(OT_\lambda\) is an optimal transport distance, given either by the Sinkhorn distance
\[
S_\lambda(a, b) = \min_P \, \langle C, P \rangle + \lambda \langle P, \log(P) \rangle - \lambda \quad \text{s.t. } P \mathbb{1} = a, \; P^\top \mathbb{1} = b,
\]
for some cost matrix \(C_{i,j} = c(x_i, y_j)\) between the support points \(x_i\) of the first measure and \(y_j\) of the second, or by the Sinkhorn divergence
\[
S_\lambda(a, b) - 0.5\, S_\lambda(a, a) - 0.5\, S_\lambda(b, b).
\]
The constraint bounds the difference between the weighted source mean and the target mean of each balance function \(B\), measured in units of its original standard deviation \(\sigma\), by \(\delta\).
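To make these formulas concrete, here is a minimal base-R sketch of the Sinkhorn distance \(S_\lambda\) and the Sinkhorn divergence for small problems. It assumes a squared Euclidean ground cost \(c(x, y) = \lVert x - y \rVert^2\) and uses the plain (not log-domain) Sinkhorn scaling iterations, which can underflow when \(\lambda\) is small relative to the costs. It is not the `causalOT` implementation, which works with `torch` tensors and may scale the cost differently; the helper names are hypothetical.

```r
# Squared Euclidean ground cost between the rows of x and the rows of y
sq_cost <- function(x, y) {
  C <- outer(rowSums(x^2), rowSums(y^2), "+") - 2 * x %*% t(y)
  pmax(C, 0)  # clip tiny negative values caused by floating-point error
}

# Sinkhorn scaling for S_lambda(a, b) = <C, P> + lambda * <P, log P> - lambda
sinkhorn <- function(a, b, C, lambda, niter = 5000, tol = 1e-10) {
  K <- exp(-C / lambda)
  v <- rep(1, length(b))
  for (it in seq_len(niter)) {
    u <- a / as.vector(K %*% v)             # rescale to match the row marginal a
    v <- b / as.vector(crossprod(K, u))     # rescale to match the column marginal b
    if (max(abs(u * as.vector(K %*% v) - a)) < tol) break
  }
  P <- sweep(u * K, 2, v, "*")              # P = diag(u) K diag(v)
  pos <- P > 0
  value <- sum(C * P) + lambda * sum(P[pos] * log(P[pos])) - lambda
  list(plan = P, value = value)
}

# Debiased Sinkhorn divergence between the weighted point clouds (x, a) and (y, b)
sinkhorn_divergence <- function(x, a, y, b, lambda) {
  sinkhorn(a, b, sq_cost(x, y), lambda)$value -
    0.5 * sinkhorn(a, a, sq_cost(x, x), lambda)$value -
    0.5 * sinkhorn(b, b, sq_cost(y, y), lambda)$value
}

set.seed(3)
xs <- matrix(stats::rnorm(5 * 3, sd = 0.5), 5, 3)
ys <- matrix(stats::rnorm(10 * 3, sd = 0.5), 10, 3)
wa <- rep(1 / 5, 5)
wb <- rep(1 / 10, 10)
fit <- sinkhorn(wa, wb, sq_cost(xs, ys), lambda = 1)
rowSums(fit$plan)                                      # approximately wa
sinkhorn_divergence(xs, wa, ys, wb, lambda = 1)        # positive for different measures
sinkhorn_divergence(xs, wa, xs, wa, lambda = 1)        # exactly 0 for identical measures
sinkhorn_divergence(xs, wa, xs + 3, wa, lambda = 1)    # shift tau = (3, 3, 3): ||tau||^2 = 27
```

The last line is a useful sanity check: under the squared Euclidean cost, translating a measure by \(\tau\) adds the same constant \(\lVert \tau \rVert^2\) to the cost of every feasible plan, so the debiased divergence equals \(\lVert \tau \rVert^2 = 27\) while the entropic bias terms cancel.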

With this detail, we then need to specify which optimal transport problem we’re using, the various penalty parameters, and so on. To do this, we use the `setup_arguments` function below:

```
otp$setup_arguments(
  lambda = NULL,         # penalty values of the optimal transport (OT) distances to try
  delta = NULL,          # constraint values to try for balancing functions
  grid.length = 7L,      # number of values of lambda and delta to try
                         # if none are provided
  cost.function = NULL,  # the ground cost to use between covariates
                         # default is the Euclidean distance
  p = 2,                 # power to raise the cost by
  cost.online = "auto",  # should the cost be calculated "online" or "tensorized" (stored in memory)?
                         # "auto" will try to decide for you
  debias = TRUE,         # use Sinkhorn divergences (debias = TRUE), i.e. debiased Sinkhorn distances,
                         # or use the Sinkhorn distances (debias = FALSE)
  diameter = NULL,       # the diameter of the covariate space if known
  ot_niter = 1000L,      # the number of iterations to run when solving OT distances
  ot_tol = 0.001         # the tolerance for convergence of OT distances
)
```

The last two arguments, `ot_niter` and `ot_tol`, may be confusing at first, but understanding how the `OTProblem` objects adapt the measure may help to add some clarity. The `OTProblem` first has to solve an optimal transport problem between the two measures (with runtime parameters specified in the `setup_arguments` function). Then the object takes a step updating the weights, which is handled by the `solve` function described next.
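To illustrate the "solve OT, then update the weights" pattern, here is a small base-R sketch. It is not what `causalOT` does internally (its `solve` function uses `torch` optimizers such as L-BFGS, or a Frank-Wolfe method); instead it swaps in an exponentiated-gradient (mirror descent) step. It relies on the fact that, for the Sinkhorn distance \(S_\lambda\), the dual potential \(\lambda \log u\) built from the Sinkhorn scaling vector \(u\) is the gradient with respect to the source weights up to an additive constant. The sketch minimizes the plain Sinkhorn distance (the `debias = FALSE` case) under a squared Euclidean cost, and the step size `eta` with simple backtracking is a hypothetical choice.

```r
sq_cost <- function(x, y) {
  pmax(outer(rowSums(x^2), rowSums(y^2), "+") - 2 * x %*% t(y), 0)
}

# Sinkhorn scaling; also returns u, since lambda * log(u) is the gradient of
# S_lambda(a, b) with respect to a, up to an additive constant
sinkhorn <- function(a, b, C, lambda, niter = 5000, tol = 1e-10) {
  K <- exp(-C / lambda)
  v <- rep(1, length(b))
  for (it in seq_len(niter)) {
    u <- a / as.vector(K %*% v)
    v <- b / as.vector(crossprod(K, u))
    if (max(abs(u * as.vector(K %*% v) - a)) < tol) break
  }
  P <- sweep(u * K, 2, v, "*")
  pos <- P > 0
  list(u = u, value = sum(C * P) + lambda * sum(P[pos] * log(P[pos])) - lambda)
}

set.seed(4)
xs <- matrix(stats::rnorm(20 * 2, sd = 0.5), 20, 2)              # source support (fixed)
ys <- matrix(stats::rnorm(30 * 2, mean = 0.5, sd = 0.5), 30, 2)  # target support
C <- sq_cost(xs, ys)
wb <- rep(1 / 30, 30)   # target weights
lambda <- 1

w <- rep(1 / 20, 20)    # source weights to adapt
eta <- 0.5              # step size, halved whenever a step fails to lower the loss
history <- numeric(0)
for (step in 1:25) {
  fit <- sinkhorn(w, wb, C, lambda)          # 1) solve the OT problem at the current weights
  history <- c(history, fit$value)
  f <- lambda * log(fit$u)                   # 2) gradient in w (up to a constant)
  accepted <- FALSE
  while (!accepted && eta > 1e-8) {          # 3) exponentiated-gradient step with backtracking
    w_new <- w * exp(-eta * (f - min(f)))
    w_new <- w_new / sum(w_new)
    if (sinkhorn(w_new, wb, C, lambda)$value < fit$value) accepted <- TRUE else eta <- eta / 2
  }
  if (!accepted) break                       # no further decrease found
  w <- w_new
}
final_value <- sinkhorn(w, wb, C, lambda)$value
c(start = history[1], end = final_value)
```

Each accepted step strictly lowers the objective, and the multiplicative update keeps the weights positive and summing to one, so they remain a probability measure throughout.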

Once we have set up the arguments, we can solve this `OTProblem`:

```
otp$solve(
  niter = 1000L,                                       # maximum number of iterations
  tol = 1e-5,                                          # tolerance for convergence
  optimizer = "torch",                                 # which optimizer to use: "torch" or "frank-wolfe"
  torch_optim = torch::optim_lbfgs,                    # torch optimizer to use if required
  torch_scheduler = torch::lr_reduce_on_plateau,       # torch scheduler to use if required
  torch_args = list(line_search_fn = "strong_wolfe"),  # args passed to the torch functions
  osqp_args = NULL,                                    # arguments passed to the osqp solver used for "frank-wolfe" and balance functions
  quick.balance.function = TRUE                        # if balance functions are also present, should an approximate
                                                       # value of the hyperparameter "delta" be found first?
)
```

Since the objects are passed by reference, the weights of the measure object that was adapted are now different.

```
#> adapted original
#> [1,] 3.919938e-01 0.2227984
#> [2,] 1.150060e-01 0.1880527
#> [3,] 1.802888e-09 0.2385908
#> [4,] 4.930001e-01 0.2295054
#> [5,] 6.598182e-08 0.1210527
```

*Note:* the dual optimization method currently available for the `COT` method in the `calc_weight` function is not implemented for `OTProblem` objects. Thus, these optimization problems may take longer to solve.

We have run the function with a variety of `lambda` parameters chosen by the `OTProblem` object. We should select one to move forward with:

```
otp$choose_hyperparameters(
  n_boot_lambda = 100L,   # number of bootstrap iterations to choose lambda
  n_boot_delta = 1000L,   # number of bootstrap iterations to choose delta
  lambda_bootstrap = Inf  # penalty parameter to use for OT distances
)
```

The `delta` parameter wasn’t used, so we only select the value of `lambda`. This gives us a final value of `lambda` and the following final weights:

```
as.numeric(m_source$weights)
#> [1] 2.489658e-01 1.233070e-01 1.145063e-09 6.277272e-01 4.190680e-08
```

We can also see the final value of the optimal transport problem with the chosen value of `lambda` and weights.

In summary, we have the following steps to solve our causal inference problems using optimal transport.

- Construct the `Measure` objects

```
m_target <- Measure(x = matrix(rnorm(n*2*d), n*2,d))
m_source <- Measure(x = x, weights = w, adapt = "weights")
```

- Construct the `OTProblem`

- Set up the arguments of the `OTProblem`

```
otp$setup_arguments(
  lambda = NULL,         # penalty values of the optimal transport (OT) distances to try
  delta = NULL,          # constraint values to try for balancing functions
  grid.length = 7L,      # number of values of lambda and delta to try
                         # if none are provided
  cost.function = NULL,  # the ground cost to use between covariates
                         # default is the Euclidean distance
  p = 2,                 # power to raise the cost by
  cost.online = "auto",  # should the cost be calculated "online" or "tensorized" (stored in memory)?
                         # "auto" will try to decide for you
  debias = TRUE,         # use Sinkhorn divergences (debias = TRUE), i.e. debiased Sinkhorn distances,
                         # or use the Sinkhorn distances (debias = FALSE)
  diameter = NULL,       # the diameter of the covariate space if known
  ot_niter = 1000L,      # the number of iterations to run when solving OT distances
  ot_tol = 0.001         # the tolerance for convergence of OT distances
)
```

- Solve the `OTProblem`

```
otp$solve(
  niter = 1000L,                                       # maximum number of iterations
  tol = 1e-5,                                          # tolerance for convergence
  optimizer = "torch",                                 # which optimizer to use: "torch" or "frank-wolfe"
  torch_optim = torch::optim_lbfgs,                    # torch optimizer to use if required
  torch_scheduler = torch::lr_reduce_on_plateau,       # torch scheduler to use if required
  torch_args = list(line_search_fn = "strong_wolfe"),  # args passed to the torch functions
  osqp_args = NULL,                                    # arguments passed to the osqp solver used for "frank-wolfe" and balance functions
  quick.balance.function = TRUE                        # if balance functions are also present, should an approximate
                                                       # value of the hyperparameter "delta" be found first?
)
```

- Select hyperparameter values (if needed)

```
otp$choose_hyperparameters(
  n_boot_lambda = 100L,   # number of bootstrap iterations to choose lambda
  n_boot_delta = 1000L,   # number of bootstrap iterations to choose delta
  lambda_bootstrap = Inf  # penalty parameter to use for OT distances
)
```

- You’re done!

The case above was simply a vanilla optimal transport problem that could easily be solved by the `calc_weight` function in the main package. Let’s look at a more complicated use case.

Ideally, we’d simply dump the data together and run our OT framework.

```
nrow <- 100
ncol <- 2
a <- Measure(x = matrix(rnorm(nrow*ncol,mean=c(0.1,0.1)) + 0.1,nrow,ncol,byrow = TRUE), adapt = "weights")
b <- Measure(x = matrix(rnorm(nrow*ncol,mean=c(-0.1,-0.1),sd=0.25),nrow,ncol,byrow = TRUE), adapt = "weights")
c <- Measure(x = matrix(rnorm(nrow*ncol,mean=c(0.1,-0.1)),nrow,ncol,byrow = TRUE), adapt = "weights")
d <- Measure(x = matrix(rnorm(nrow*ncol,mean= c(-0.1,0.1),sd=0.25),nrow,ncol,byrow = TRUE), adapt = "weights")
overall <- Measure(x = torch::torch_vstack(lapply(list(a, b, c, d), function(meas) meas$x)),
                   adapt = "none")
overall_ot <- OTProblem(a, overall) + OTProblem(b, overall) +
  OTProblem(c, overall) + OTProblem(d, overall)
```

One thing to note that’s kind of cool is that we can add our `OTProblem` objects together to make a unified objective function:

```
overall_ot
#> OT Problem:
#> OT(0x7fdd7a0af5d0, 0x7fdd7a2b50a0) +
#> OT(0x7fdd7a15d830, 0x7fdd7a2b50a0) +
#> OT(0x7fdd7a1e1798, 0x7fdd7a2b50a0) +
#> OT(0x7fdd79b17810, 0x7fdd7a2b50a0)
```

Neat!
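We have not looked at how `causalOT` implements this addition, but the general idea of combining objective terms with `+` can be sketched with a small S3 class in base R. The class `ot_objective`, its `+` method, and `total_loss()` below are hypothetical; each term is just a function returning a number, standing in for one site's OT loss.

```r
# Hypothetical S3 class (not causalOT's implementation): an objective that
# stores a list of loss terms
ot_objective <- function(loss_fn) {
  structure(list(terms = list(loss_fn)), class = "ot_objective")
}

# `+` concatenates the terms of two objectives into one objective
"+.ot_objective" <- function(e1, e2) {
  structure(list(terms = c(e1$terms, e2$terms)), class = "ot_objective")
}

# The unified loss is the sum of every stored term
total_loss <- function(obj) {
  sum(vapply(obj$terms, function(term) term(), numeric(1)))
}

# Stand-ins for per-site OT losses (constants here, purely for illustration)
site_a <- ot_objective(function() 1.5)
site_b <- ot_objective(function() 2.0)
site_c <- ot_objective(function() 0.5)
combined <- site_a + site_b + site_c
length(combined$terms)   # 3
total_loss(combined)     # 4
```

Keeping the terms separate, rather than summing them immediately, means the combined object can still be printed term by term, much like the `OT(source, target)` pairs in the printout above.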

We can also run the `calc_weight` function in each treatment group, targeting the overall population:
```
source_measures <- list(a,b,c,d)
meas <- x_temp <- NULL
z_temp <- c(rep(1, nrow*4), rep(0,nrow))
wt <- list()
for (i in seq_along(source_measures)) {
  meas <- source_measures[[i]]
  x_temp <- as.matrix(torch::torch_vstack(list(overall$x, meas$x)))
  wt[[i]] <- calc_weight(x = x_temp,
                         z = z_temp,
                         estimand = "ATT",
                         method = "COT")
}
```

If only moments are available, then each site can run essentially independently. We just need to collect the moments from each site and combine them:

```
target.values <-
  as.numeric(a$x$mean(1) + b$x$mean(1) +
               c$x$mean(1) + d$x$mean(1))/4
a_t <- Measure(x = a$x, adapt = "weights",
               target.values = target.values)
b_t <- Measure(x = b$x, adapt = "weights",
               target.values = target.values)
c_t <- Measure(x = c$x, adapt = "weights",
               target.values = target.values)
d_t <- Measure(x = d$x, adapt = "weights",
               target.values = target.values)
all.target.measures <- list(a_t, b_t, c_t, d_t)
```
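The division by 4 above gives the pooled mean only because every site has the same number of observations (`nrow = 100`). If sample sizes differed, each site could also share its size, and the main site would weight the site means accordingly. A minimal base-R sketch with made-up sites of different sizes:

```r
set.seed(6)
sites_demo <- list(
  matrix(stats::rnorm(100 * 2, mean = 0.1), 100, 2),
  matrix(stats::rnorm(50 * 2, mean = -0.1), 50, 2),
  matrix(stats::rnorm(200 * 2), 200, 2)
)
# What each site shares: its covariate means and its sample size
site_means <- t(sapply(sites_demo, colMeans))   # one row of means per site
site_n <- sapply(sites_demo, nrow)

# Weight each site's means by its share of the pooled sample
pooled_means <- colSums(site_means * site_n) / sum(site_n)
# A simple average of site means is only correct when all sites have equal size
simple_means <- colMeans(site_means)
rbind(pooled = pooled_means, simple = simple_means)
```

The pooled means agree with the means of the stacked data, so only the per-site means and counts need to leave each site.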

Then we can optimize the weights targeting the moments in a bit of a hacky way.

```
ot_targ <- NULL
for (meas in all.target.measures) {
  ot_targ <- OTProblem(meas, meas)
  ot_targ$setup_arguments(lambda = 100)
  ot_targ$solve(torch_optim = torch::optim_lbfgs,
                torch_args = list(line_search_fn = "strong_wolfe"))
}
```

And we can check the final balance:

```
final.bal <- as.numeric(a_t$x$mT()$matmul(a_t$weights$detach()))
original <- as.numeric(a_t$x$mean(1))
rbind(original,
`final balance` = final.bal,
`target values` = target.values)
#> [,1] [,2]
#> original 0.08015652 0.31129337
#> final balance -0.01093957 0.09104138
#> target values -0.01100336 0.09096446
```

This will target the moments without information about the underlying distributions. Obviously, we would prefer to use more of the available information, as we describe next.
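For comparison, a classical moment-only approach is entropy balancing, which finds the weights closest to uniform in Kullback-Leibler divergence whose weighted means equal the target. This is a different technique from the `OTProblem(meas, meas)` construction above and is shown only to make the moment-matching idea concrete. The sketch solves the convex dual with `stats::optim`; the data and target moments are hypothetical.

```r
set.seed(5)
X_eb <- matrix(stats::rnorm(200 * 2), 200, 2)   # balance functions at one site
target_eb <- c(0.2, -0.1)                       # hypothetical target moments

# Dual objective of entropy balancing with uniform base weights (convex in theta):
#   log(sum_i exp(B(x_i)' theta)) - theta' target
eb_dual <- function(theta) {
  s <- as.vector(X_eb %*% theta)
  max(s) + log(sum(exp(s - max(s)))) - sum(theta * target_eb)
}
# Weights implied by theta: w_i proportional to exp(B(x_i)' theta)
eb_weights <- function(theta) {
  s <- as.vector(X_eb %*% theta)
  w <- exp(s - max(s))          # subtract the max for numerical stability
  w / sum(w)
}
# Gradient of the dual: weighted means minus the target
eb_grad <- function(theta) colSums(eb_weights(theta) * X_eb) - target_eb

fit_eb <- stats::optim(c(0, 0), eb_dual, eb_grad, method = "BFGS",
                       control = list(reltol = 1e-12))
w_eb <- eb_weights(fit_eb$par)
rbind(weighted = colSums(w_eb * X_eb), target = target_eb)
```

Like the example above, this matches the moments exactly but says nothing about the rest of the distributions.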

Say we can pass any amount of information between sites, but privacy or other restrictions prevent us from sharing the full data at each site. We can instead construct a pseudo-overall population using Wasserstein barycenters, which are average distributions.

In this first option, we pass gradients back to the main site, from which we can construct a pseudo-average population. Let’s see how it might work. We first construct our pseudo data.

Importantly, each data point must be initialized to a separate value; otherwise, all of the points will move together. Then we pass this pseudo data to the sites and set up a problem at each site.

```
pseudo_a <- pseudo$detach()
pseudo_b <- pseudo$detach()
pseudo_c <- pseudo$detach()
pseudo_d <- pseudo$detach()
pseudo_a$requires_grad <- pseudo_b$requires_grad <-
  pseudo_c$requires_grad <- pseudo_d$requires_grad <- "x"
ota <- OTProblem(a$detach(), # don't update a
                 pseudo_a)
otb <- OTProblem(b$detach(), # don't update b
                 pseudo_b)
otc <- OTProblem(c$detach(), # don't update c
                 pseudo_c)
otd <- OTProblem(d$detach(), # don't update d
                 pseudo_d)
```

Then we set up the arguments. For simplicity, we will set `lambda = 0.1`.

```
ota$setup_arguments(lambda = .1)
otb$setup_arguments(lambda = .1)
otc$setup_arguments(lambda = .1)
otd$setup_arguments(lambda = .1)
```

Then we set up our optimizer at the main site:

```
opt <- torch::optim_rmsprop(pseudo$x)
sched <- torch::lr_multiplicative(opt, lr_lambda = function(epoch) {0.99})
```

Then we run our optimization loop like so:

```
# optimization loop
for (i in 1:100) {
  # zero grad of main optimizer
  opt$zero_grad()
  # get gradients at each site
  ota$loss$backward()
  otb$loss$backward()
  otc$loss$backward()
  otd$loss$backward()
  # pass grads back to main site
  pseudo$grad <- pseudo_a$grad + pseudo_b$grad +
    pseudo_c$grad + pseudo_d$grad
  # update pseudo data at main site
  opt$step()
  # zero site gradients
  torch::with_no_grad({
    pseudo_a$grad$copy_(0.0)
    pseudo_b$grad$copy_(0.0)
    pseudo_c$grad$copy_(0.0)
    pseudo_d$grad$copy_(0.0)
  })
  # update scheduler
  sched$step()
}
```

Then we pass the final pseudo data back to the sites and optimize the weights at each site:

```
pseudo_a$x <- pseudo_b$x <- pseudo_c$x <- pseudo_d$x <-
  pseudo$x
ota_w <- OTProblem(a, pseudo_a$detach())
otb_w <- OTProblem(b, pseudo_b$detach())
otc_w <- OTProblem(c, pseudo_c$detach())
otd_w <- OTProblem(d, pseudo_d$detach())
ota_w$setup_arguments()
ota_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
ota_w$choose_hyperparameters()
otb_w$setup_arguments()
otb_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
otb_w$choose_hyperparameters()
otc_w$setup_arguments()
otc_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
otc_w$choose_hyperparameters()
otd_w$setup_arguments()
otd_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
otd_w$choose_hyperparameters()
```

Note we haven’t checked for convergence when constructing the pseudo-data to save time. You should, however, do this in your own work.
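For intuition about what the pseudo data converge to, and to include the kind of convergence check mentioned above, here is a small base-R sketch of a different but classic way to compute a free-support barycenter: a fixed-point iteration that moves each pseudo point to the average, over sites, of its barycentric projection under an entropic OT plan. It is not the gradient loop used above; it assumes a squared Euclidean cost, uniform weights, and `lambda = 1`, and the simulated sites and helper names are hypothetical. It also shows why distinct starting values matter: pseudo points that start at the same location receive identical updates and never separate.

```r
sq_cost <- function(x, y) {
  pmax(outer(rowSums(x^2), rowSums(y^2), "+") - 2 * x %*% t(y), 0)
}
# Entropic OT plan via Sinkhorn scaling
sinkhorn_plan <- function(a, b, C, lambda, niter = 5000, tol = 1e-10) {
  K <- exp(-C / lambda)
  v <- rep(1, length(b))
  for (it in seq_len(niter)) {
    u <- a / as.vector(K %*% v)
    v <- b / as.vector(crossprod(K, u))
    if (max(abs(u * as.vector(K %*% v) - a)) < tol) break
  }
  sweep(u * K, 2, v, "*")
}

# One fixed-point update: each pseudo point y_j moves to the average, over
# sites, of its barycentric projection sum_i P[i, j] x_i / b_j
barycenter_step <- function(Y, sites, lambda) {
  b <- rep(1 / nrow(Y), nrow(Y))
  maps <- lapply(sites, function(X) {
    a <- rep(1 / nrow(X), nrow(X))
    P <- sinkhorn_plan(a, b, sq_cost(X, Y), lambda)
    crossprod(P, X) / b
  })
  Reduce(`+`, maps) / length(maps)
}

set.seed(8)
bary_sites <- list(
  matrix(stats::rnorm(40 * 2, mean = 0.5, sd = 0.5), 40, 2),
  matrix(stats::rnorm(50 * 2, mean = -0.5, sd = 0.5), 50, 2),
  matrix(stats::rnorm(30 * 2, sd = 0.25), 30, 2)
)
Y <- matrix(stats::rnorm(15 * 2, sd = 0.5), 15, 2)   # distinct random starting points
converged <- FALSE
for (iter in 1:50) {
  Y_new <- barycenter_step(Y, bary_sites, lambda = 1)
  converged <- max(abs(Y_new - Y)) < 1e-6           # have the pseudo data stopped moving?
  Y <- Y_new
  if (converged) break
}
converged
```

Each update preserves the mean, so the mean of the pseudo points equals the average of the site means (up to the Sinkhorn tolerance). With a larger `lambda`, the entropic blur tends to pull the pseudo points toward that mean.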

Of course, maybe we can’t pass gradients. Instead, we can create pseudo data at each site.

The second option is to create pseudo-data for each site and then use these to generate an overall average data set. This allows us to create privacy-respecting pseudo-data at each site, i.e., data that are close to the population at a site (say, A) but with different values. Then we take these pseudo-data and create an overall average data set as follows.

First, we need to reinitialize the weights at each site, since they were changed in the previous example.

```
a$weights <- a$init_weights
b$weights <- b$init_weights
c$weights <- c$init_weights
d$weights <- d$init_weights
```

Then we again create the pseudo data, pass it to the sites, and set up a problem at each site.

```
pseudo_a <- pseudo$detach()
pseudo_b <- pseudo$detach()
pseudo_c <- pseudo$detach()
pseudo_d <- pseudo$detach()
pseudo_a$requires_grad <- pseudo_b$requires_grad <-
  pseudo_c$requires_grad <- pseudo_d$requires_grad <- "x"
ota <- OTProblem(a$detach(), # don't update a
                 pseudo_a)
otb <- OTProblem(b$detach(), # don't update b
                 pseudo_b)
otc <- OTProblem(c$detach(), # don't update c
                 pseudo_c)
otd <- OTProblem(d$detach(), # don't update d
                 pseudo_d)
```

Then we set up the arguments. For simplicity, we will set `lambda = 0.1`.

```
ota$setup_arguments(lambda = .1)
otb$setup_arguments(lambda = .1)
otc$setup_arguments(lambda = .1)
otd$setup_arguments(lambda = .1)
```

Then we solve for the pseudo-data at each site, running each problem separately.

```
# run separately at each site
ota$solve(torch_optim = torch::optim_rmsprop)
otb$solve(torch_optim = torch::optim_rmsprop)
otc$solve(torch_optim = torch::optim_rmsprop)
otd$solve(torch_optim = torch::optim_rmsprop)
```

Looking at the pseudo data in group B, we can see that the pseudo-data are a much better approximation to B after optimization.

Then we send the pseudo-data back to our main site to create the overall pseudo-data:

```
# send back to the main site and create overall problem
ot_overall <-
  OTProblem(pseudo_a$detach(), pseudo) +
  OTProblem(pseudo_b$detach(), pseudo) +
  OTProblem(pseudo_c$detach(), pseudo) +
  OTProblem(pseudo_d$detach(), pseudo)
ot_overall$setup_arguments(lambda = 0.1)
ot_overall$solve(torch_optim = torch::optim_rmsprop)
```

Now we have an average population to target at each site, which we can do like so:

```
# pass pseudo to each site, then set up the problems again;
# pseudo is detached, so only the weights at each site are updated
ota2 <- OTProblem(a, pseudo$detach())
otb2 <- OTProblem(b, pseudo$detach())
otc2 <- OTProblem(c, pseudo$detach())
otd2 <- OTProblem(d, pseudo$detach())
all.problems <- list(ota2, otb2, otc2, otd2)
# then we optimize the weights at each site separately.
for (prob in all.problems) {
  prob$setup_arguments()
  prob$solve(
    torch_optim = torch::optim_lbfgs,
    torch_args = list(line_search_fn = "strong_wolfe")
  )
  prob$choose_hyperparameters()
}
```

The final example is a situation where we may have covariate, treatment, and outcome data at one location and want to use it to infer effects in another population with only covariate data. Say we have a binary treatment at our source site and only moments available from the target site.

```
x_1 <- matrix(rnorm(128*2), 128) +
  matrix(c(-0.1, -0.1), 128, 2, byrow = TRUE)
x_2 <- matrix(rnorm(256*2), 256) +
  matrix(c(0.1, 0.1), 256, 2, byrow = TRUE)
target.data <- matrix(rnorm(512*2), 512, 2) * 0.5 +
  matrix(c(0.1, -0.1), 512, 2, byrow = TRUE)
constructor.formula <- formula("~ 0 + . + I(V1^2) + I(V2^2)")
target.values <- colMeans(model.matrix(constructor.formula,
                                       as.data.frame(target.data)))
m_1 <- Measure(x = x_1, adapt = "weights",
               balance.functions = model.matrix(constructor.formula,
                                                as.data.frame(x_1)),
               target.values = target.values)
m_2 <- Measure(x = x_2, adapt = "weights",
               balance.functions = model.matrix(constructor.formula,
                                                as.data.frame(x_2)),
               target.values = target.values)
ot_binary <- OTProblem(m_1, m_2)
```
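The `constructor.formula` above turns the two covariates into four balance functions, \(V_1, V_2, V_1^2, V_2^2\), so `target.values` holds the first and second raw moments of the target data. The short base-R check below, on a hypothetical toy data frame, shows the columns that `model.matrix` produces and that their means are exactly those moments.

```r
set.seed(7)
toy <- as.data.frame(matrix(stats::rnorm(6 * 2), 6, 2))   # columns are named V1 and V2
constructor.formula <- stats::formula("~ 0 + . + I(V1^2) + I(V2^2)")
B_toy <- stats::model.matrix(constructor.formula, toy)
colnames(B_toy)   # "V1" "V2" "I(V1^2)" "I(V2^2)"
# The column means of B_toy are the first and second raw moments of each covariate
colMeans(B_toy)
```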

In this case, we’d like the two treatment groups to have the same distribution as each other while also matching the first and second moments of our target site.

```
ot_binary$setup_arguments()
ot_binary$solve(torch_optim = torch::optim_lbfgs,
                torch_args = list(line_search_fn = "strong_wolfe"))
ot_binary$choose_hyperparameters()
```

We now check to see how everything looks using the `info()` function.

```
info <- ot_binary$info()
names(info)
#> [1] "loss" "iterations"
#> [3] "balance.function.differences" "hyperparam.metrics"
```

We can see a variety of things like the metrics from the hyperparameter selection, iterations run, final loss, etc. We can also see how the balance functions are doing in terms of targeting the moments.

```
info$balance.function.differences
#> $`0x7fdd8cdf5660`
#> $`0x7fdd8cdf5660`$balance
#> torch_tensor
#> 0.001 *
#> -3.3433
#> 3.5927
#> -2.4325
#> 3.5927
#> [ CPUDoubleType{4} ][ grad_fn = <SubBackward0> ]
#>
#> $`0x7fdd8cdf5660`$delta
#> [1] 1e-04
#>
#>
#> $`0x7fdd9e7e9d60`
#> $`0x7fdd9e7e9d60`$balance
#> torch_tensor
#> 0.001 *
#> 1.8163
#> -0.3774
#> 1.8235
#> -1.8946
#> [ CPUDoubleType{4} ][ grad_fn = <SubBackward0> ]
#>
#> $`0x7fdd9e7e9d60`$delta
#> [1] 1e-04
```

It appears that all of our balance functions are close to their targets, with differences on the order of \(10^{-3}\). Finally, the optimal transport distance between treatments 1 and 2 is also improved.

We have demonstrated a variety of examples here. Hopefully we have made it clear that you can also compute regular optimal transport barycenters even when causal inference isn’t the goal. You can even use the `OTProblem` to solve optimal transport problems when there are no weights or data to adapt.