Inverse Probability Weighting
\[ \begin{aligned} \tilde\mu &\overset{\texttip{\small{\unicode{x2753}}}{it minimizes the average of the squared errors we make on the population}}{=} \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \frac{1}{m}\sum_{j=1}^m \qty{ y_j - m(w_j,x_j) }^2 \\ &\overset{\texttip{\small{\unicode{x2753}}}{this average can be rewritten as the expected squared error for $(W_i,X_i,Y_i)$ drawn uniformly-at-random from the population}}{=} \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \mathop{\mathrm{E}}\qty[ \qty{ Y_i - m(W_i,X_i) }^2 ] \\ \end{aligned} \]
\[ \tilde\mu(w,x) = \mathop{\mathrm{E}}[ \hat\mu(w,x) ] \qfor \hat\mu = \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \frac{1}{n}\sum_{i=1}^n \qty{ Y_i - m(W_i,X_i) }^2. \]
\[ \begin{aligned} \tilde\mu &= \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \sum_{j=1}^m \left\{ y_j - m(w_j,x_j) \right\}^2 \\ &= \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \sum_{j=1}^m \left\{ \mu(w_j,x_j) - m(w_j,x_j) \right\}^2. \end{aligned} \]
\[ Y_i = \mu(W_i,X_i) + \varepsilon_i \qqtext{ where } \mathop{\mathrm{E}}[\varepsilon_i \mid W_i, X_i] = 0. \]
Intuition and Empirical Evidence.
\[ \color{gray} \begin{aligned} \Delta_{\text{raw}} &= \textcolor[RGB]{0,191,196}{\frac{1}{m_1}\sum_{j:w_j=1} y_j} - \textcolor[RGB]{248,118,109}{\frac{1}{m_0}\sum_{j:w_j=0} y_j} &&\overset{\texttip{\small{\unicode{x2753}}}{translated into histogram form}}{=} \quad \sum_{x} \textcolor[RGB]{0,191,196}{p_{x \mid 1} \ \mu(1,x)} - \sum_{x} \textcolor[RGB]{248,118,109}{p_{x \mid 0} \ \mu(0,x)} \\ \tilde\Delta_{0} &= \textcolor[RGB]{248,118,109}{\frac{1}{m_0}\sum_{j:w_j=0}} \color{gray}\left\{ \color{black} \textcolor[RGB]{0,191,196}{\tilde \mu(1,x_j)} - \textcolor[RGB]{248,118,109}{\tilde\mu(0,x_j) } \color{gray} \right\} \color{black} &&\overset{\texttip{\small{\unicode{x2753}}}{translated into histogram form}}{=} \quad \sum_x \textcolor[RGB]{248,118,109}{p_{x \mid 0}} \ \textcolor[RGB]{0,191,196}{\tilde \mu(1,x)} - \sum_x \textcolor[RGB]{248,118,109}{p_{x \mid 0}} \ \textcolor[RGB]{248,118,109}{\tilde\mu(0,x) }. \end{aligned} \]
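To make this concrete, here is a minimal sketch of both quantities in code. It assumes the population data frame pop used later in this section, with columns w, x, and y, and that every column containing red dots also contains some green dots; the names Delta.raw, cell.means, mu.tilde, and Delta0.tilde are ours, and mu.tilde is the all-functions-model predictor, i.e. the mean of y within each (w, x) column.
library(dplyr)
# raw difference in means between the green (w=1) and red (w=0) groups
Delta.raw = mean(pop$y[pop$w==1]) - mean(pop$y[pop$w==0])
# all-functions-model predictor: the mean of y within each (w,x) column
cell.means = pop |> group_by(w,x) |> summarize(mu=mean(y), .groups='drop')
mu.tilde = function(w,x) cell.means$mu[match(paste(w,x), paste(cell.means$w, cell.means$x))]
# adjusted difference: average the predicted difference over the red dots' covariates
x0 = pop$x[pop$w==0]
Delta0.tilde = mean(mu.tilde(1,x0) - mu.tilde(0,x0))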
How does the raw difference in means \(\Delta_{\text{raw}}\) compare to the adjusted difference \(\tilde\Delta_{0}\)
when our adjustment uses the least squares predictor in the all functions model?
See Slide 4.1 for the answer.
How does the raw difference in means \(\Delta_{\text{raw}}\) compare to the adjusted difference \(\tilde\Delta_{0}\)
when our adjustment uses the least squares predictor in the horizontal lines model?
See Slide 4.2 for the answer.
Why, in intuitive terms, is the constant \(\tilde\mu(1,x)\) not the right prediction to use in this comparison? Hint. Think about the extreme covariate shift case. Which green dots does it fit well? Which does it not fit well? Why?
See Slide 4.3 for the answer.
What we want \(\tilde\mu(1,x)\) to be good at (when we’re using it in \(\tilde\Delta_{0}\)) and what we’re asking it to be good at (when we’re doing least squares) are different things. When there’s a lot of covariate shift, they’re very different things.
Let’s think about what the population least squares predictor \(\tilde\mu(w,x)\) looks like in histogram form. \[ \begin{aligned} \tilde\mu = \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \ \text{MSE}(m) \qfor \text{MSE}(m) &= \frac{1}{m}\sum_{wx}\sum_{j:w_j=w, x_j=x} \color{gray}\left\{ \color{black} \mu(w_j,x_j) - m(w_j,x_j) \color{gray} \right\} \color{black}^2 \\ &= \frac{1}{m}\sum_{wx}\sum_{j:w_j=w, x_j=x} \color{gray}\left\{ \color{black} \mu(w, x) - m(w,x) \color{gray} \right\} \color{black}^2 \\ &= \frac{1}{m}\sum_{wx} m_{wx} \color{gray}\left\{ \color{black} \mu(w, x) - m(w,x) \color{gray} \right\} \color{black}^2 \\ \end{aligned} \]
To pretend that we have the same number of green dots as red dots in each column, we replace the count \(m_{wx}\) in each term with the number of red dots \(\textcolor[RGB]{248,118,109}{m_{0x}}\) in that column.
Equivalently, we weight each term by a factor of \(m_{0x}/m_{wx}\). We’ll call this weight \(\gamma(w,x)\). \[ \begin{aligned} \tilde\mu^{\text{IPW}} = \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \ \text{WMSE}(m) \qfor \text{WMSE}(m) &= \frac{1}{m}\sum_{wx} \textcolor[RGB]{248,118,109}{m_{0x}} \color{gray}\left\{ \color{black} \mu(w, x) - m(w,x) \color{gray} \right\} \color{black}^2 \\ &= \frac{1}{m}\sum_{wx} \gamma(w,x) m_{wx}\color{gray}\left\{ \color{black} \mu(w, x) - m(w,x) \color{gray} \right\} \color{black}^2 \\ \qfor &\gamma(w,x) = \frac{m_{0x}}{m_{wx}} = \begin{cases} 1 & \text{if } w=0 \\ \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{\textcolor[RGB]{0,191,196}{m_{1x}}} & \text{if } w=1 \end{cases} \end{aligned} \]
And if we want to think about this as a sum over the population …
we can just give each person in our sum the weight \(\gamma(w_j,x_j)\) for their column.
\[ \begin{aligned} \tilde\mu^{\text{IPW}} = \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \ \text{WMSE}(m) \qfor \text{WMSE}(m) &= \frac{1}{m}\sum_{j=1}^m \gamma(w_j,x_j) \color{gray}\left\{ \color{black} \mu(w_j, x_j) - m(w_j,x_j) \color{gray} \right\} \color{black}^2 \end{aligned} \]
\[ \begin{aligned} \hat\Delta_0 &= \frac{1}{m_0} \sum_{j:w_j=0} \qty{ \hat\mu(1,x_j) - \hat\mu(0,x_j) } \\ & \qfor \hat\mu = \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \frac{1}{n}\sum_{i=1}^n \gamma(W_i,X_i) \color{gray}\left\{ \color{black} Y_i - m(W_i,X_i) \color{gray} \right\} \color{black}^2 \\ & \qand \gamma(w,x) = \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{m_{wx}} = \begin{cases} 1 & \text{if } w=0 \\ \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{\textcolor[RGB]{0,191,196}{m_{1x}}} & \text{if } w=1 \end{cases} \end{aligned} \]
pop.summaries = pop |> group_by(w,x) |> summarize(mwx=n(), .groups='drop')
mwx = summary.lookup('mwx', pop.summaries)
gamma = function(w,x) ifelse(w==0, 1, mwx(0,x) / mwx(1,x))   # 1 when w=0, m0x/m1x when w=1, as in the formula above
sam$weights = gamma(sam$w, sam$x)
fitted.model = lm(y~1+w, weights=weights, data=sam)
muhat = function(w,x) predict(fitted.model, newdata=data.frame(w=w,x=x))
Delta0.hat = mean(muhat(1,pop$x[pop$w==0]) - muhat(0,pop$x[pop$w==0]))   # average over the untreated population, as in the formula for Delta_0
Delta0.hat
Here we compute the column counts \(m_{wx}\) using group_by and summarize, and look them up with the function summary.lookup from the Week 9 homework. We then fit our weighted least squares predictor with lm, telling it to use the weights we’ve put in the column sam$weights.
[1] -10740.82
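As a quick check that the weights do what we designed them to do, we can verify that after weighting, each column of dots counts as if it had \(\textcolor[RGB]{248,118,109}{m_{0x}}\) dots, i.e. as many as the red column. This is a sketch using pop.summaries, mwx, and gamma from the code above, assuming every column has dots of both colors as in the example; the mutate call, the all.equal check, and the names check and weighted.count are ours.
check = pop.summaries |> mutate(weighted.count = gamma(w,x) * mwx)   # mwx here is the count column of pop.summaries
all.equal(check$weighted.count, mwx(0, check$x))                     # TRUE: every weighted column has m0x dots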
\[ \hat\gamma(w,x) = \frac{\textcolor[RGB]{248,118,109}{N_{0x}}}{N_{wx}} = \begin{cases} 1 & \text{if } w=0 \\ \frac{\textcolor[RGB]{248,118,109}{N_{0x}}}{\textcolor[RGB]{0,191,196}{N_{1x}}} & \text{if } w=1 \end{cases} \qfor N_{wx} = \sum_{i:W_i=w, X_i=x} 1 \]
sam.summaries = sam |> group_by(w,x) |> summarize(Nwx=n(), .groups='drop')
Nwx = summary.lookup('Nwx', sam.summaries)
gammahat = function(w,x) ifelse(w==0, 1, Nwx(0,x) / Nwx(1,x))   # 1 when w=0, N0x/N1x when w=1, as in the formula above
sam$weights = gammahat(sam$w, sam$x)
fitted.model = lm(y~1+w, weights=weights, data=sam)
muhat = function(w,x) predict(fitted.model, newdata=data.frame(w=w,x=x))
Delta0.hat = mean(muhat(1,sam$x[sam$w==0]) - muhat(0,sam$x[sam$w==0]))   # average over the untreated observations in the sample
Delta0.hat
As before, we fit with lm, telling it to use the weights we’ve put in the column sam$weights.
[1] -11617.36
\[ \mathop{\mathrm{E}}[\hat\Delta_0] = \Delta_0 \qfor \Delta_0 = \frac{1}{m_0} \sum_{j:w_j=0} \qty{ \mu(1,x_j) - \mu(0,x_j) } \]
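Here is a minimal simulation sketch of this claim (not the proof; that’s the homework). It assumes the population pop, the helper summary.lookup from the Week 9 homework, and the sample size nrow(sam) from the code above, and that every (w,x) column shows up in each draw; the function one.draw and the number of replications are ours.
one.draw = function(size) {
  sam = pop[sample(nrow(pop), size), ]                              # a fresh sample, without replacement
  sam.summaries = sam |> group_by(w,x) |> summarize(Nwx=n(), .groups='drop')
  Nwx = summary.lookup('Nwx', sam.summaries)
  sam$weights = ifelse(sam$w==0, 1, Nwx(0,sam$x) / Nwx(1,sam$x))    # estimated weights, as above
  fitted.model = lm(y~1+w, weights=weights, data=sam)
  muhat = function(w,x) predict(fitted.model, newdata=data.frame(w=w,x=x))
  mean(muhat(1,sam$x[sam$w==0]) - muhat(0,sam$x[sam$w==0]))
}
mean(replicate(1000, one.draw(nrow(sam))))                          # should be close to Delta_0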
An Exercise.
How should we weight to estimate \(\Delta_0\)?
How should we weight to estimate \(\Delta_1\)?
How should we weight to estimate \(\Delta_{\text{all}}\)?
Suppose we have this inverse probability weighted population least squares predictor \(\tilde\mu\). \[ \begin{aligned} \tilde\mu= \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \ \text{WMSE}(m) \qfor \text{WMSE}(m) &= \sum_{wx} \gamma(w,x) m_{wx} \color{gray}\left\{ \color{black} \mu(w, x) - m(w,x) \color{gray} \right\} \color{black}^2 \\ \qand &\gamma(w,x) = \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{m_{wx}}. \end{aligned} \]
We want to show that when we plug it in to our formula for \(\tilde\Delta_0\), we get our estimation target \(\Delta_0\). \[ \begin{aligned} \tilde\Delta_0 &= \frac{1}{m_0} \sum_{j:w_j=0} \qty{ \tilde\mu(1,x_j) - \tilde\mu(0,x_j) } && \text{ the population version of our estimator} \\ \Delta_0 &= \frac{1}{m_0} \sum_{j:w_j=0} \qty{ \mu(1,x_j) - \mu(0,x_j) } && \text{ our target} \end{aligned} \]
\[ \begin{aligned} \tilde\Delta_0 &= \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \textcolor[RGB]{0,191,196}{\tilde\mu(1,x_j)} - \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \textcolor[RGB]{0,191,196}{\tilde\mu(0,x_j)} \\ \Delta_0 &= \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \textcolor[RGB]{0,191,196}{\mu(1,x_j)} - \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \textcolor[RGB]{0,191,196}{\mu(0,x_j)} \end{aligned} \]
\[ \Delta_0-\tilde\Delta_0 = \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \color{gray}\left\{ \color{black} \textcolor[RGB]{0,191,196}{\mu(1,x_j) - \tilde\mu(1,x_j)} \color{gray} \right\} \color{black} - \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}}\color{gray}\left\{ \color{black} \textcolor[RGB]{248,118,109}{\mu(0,x_j)} - \textcolor[RGB]{248,118,109}{\tilde\mu(0,x_j)} \color{gray} \right\} \color{black}. \]
\[ \begin{aligned} 0 &\overset{\texttip{\small{\unicode{x2753}}}{histogram form}}{=} \sum_{wx} \textcolor[RGB]{248,118,109}{m_{0x}} \color{gray}\left\{ \color{black} \mu(w, x) - \tilde\mu(w,x) \color{gray} \right\} \color{black} m(w,x) \\ &\overset{\texttip{\small{\unicode{x2753}}}{summing over $x$, then $w$}}{=} \sum_{w \in 0,1} \sum_x \textcolor[RGB]{248,118,109}{m_{0x}} \color{gray}\left\{ \color{black} \mu(w, x) - \tilde\mu(w,x) \color{gray} \right\} \color{black} m(w,x) \\ &\overset{\texttip{\small{\unicode{x2753}}}{rewriting the histogram-form sum over $x$ as a sum over the population}}{=} \sum_{w \in 0,1} \sum_{j:w_j=0} \color{gray}\left\{ \color{black} \mu(w, x_j) - \tilde\mu(w,x_j) \color{gray} \right\} \color{black} m(w,x_j) \qqtext{ for all } m \in \mathcal{M}. \end{aligned} \tag{1}\]
\[ \begin{aligned} 0 &= \frac{d}{dt}\mid_{t=0} \text{WMSE}(\tilde \mu + t m) \\ &\overset{\texttip{\small{\unicode{x2753}}}{plugging in the definition of WMSE and distributing the derivative over terms in the sum}}{=} \sum_{wx} \gamma(w,x) m_{wx} \times \frac{d}{dt}\mid_{t=0} \color{gray}\left\{ \color{black} \mu(w, x) - \color[RGB]{239,71,111}\left\{ \color{black} \tilde\mu(w,x) + t m(w,x) \color[RGB]{239,71,111} \right\} \color{black} \color{gray} \right\} \color{black}^2 \\ &\overset{\texttip{\small{\unicode{x2753}}}{via the chain rule}}{=} \sum_{wx} \gamma(w,x) m_{wx} \times 2 \ \color{gray}\left\{ \color{black} \mu(w, x) - \tilde\mu(w,x) \color{gray} \right\} \color{black} \ \frac{d}{dt}\mid_{t=0} \color{gray}\left\{ \color{black} \mu(w, x) - \color[RGB]{239,71,111}\left\{ \color{black} \tilde\mu(w,x) + t m(w,x) \color[RGB]{239,71,111} \right\} \color{black} \color{gray} \right\} \color{black} \\ &\overset{\texttip{\small{\unicode{x2753}}}{differentiating the last factor}}{=} \sum_{wx} \gamma(w,x) m_{wx} \times 2 \ \color{gray}\left\{ \color{black} \mu(w, x) - \tilde\mu(w,x) \color{gray} \right\} \color{black} \times -m(w,x) \\ &\overset{\texttip{\small{\unicode{x2753}}}{pulling constants out of the sum}}{=} -2 \sum_{wx} \gamma(w,x) m_{wx} \color{gray}\left\{ \color{black} \mu(w, x) - \tilde\mu(w,x) \color{gray} \right\} \color{black} \ m(w,x) \\ &\overset{\texttip{\small{\unicode{x2753}}}{observing that $\gamma(w,x)m_{wx} = \textcolor[RGB]{248,118,109}{m_{0x}}$. That's why we chose $\gamma$ as we did.}}{=} -2 \sum_{wx} \textcolor[RGB]{248,118,109}{m_{0x}} \color{gray}\left\{ \color{black} \mu(w, x) - \tilde\mu(w,x) \color{gray} \right\} \color{black} m(w,x). \end{aligned} \]
We’ll plug the group indicators \(1_{=0}(w)\) and \(1_{=1}(w)\) into this orthogonality condition \[ \begin{aligned} 0 &= \sum_{w \in 0,1} \sum_{j:w_j=0} \color{gray}\left\{ \color{black} \mu(w, x_j) - \tilde\mu(w,x_j) \color{gray} \right\} \color{black} m(w,x_j) \qqtext{ for all } m \in \mathcal{M}. \end{aligned} \]
To show that the matched and mismatched error terms in this error decomposition are zero.
\[ \Delta_0-\tilde\Delta_0 = \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \color{gray}\left\{ \color{black} \textcolor[RGB]{0,191,196}{\mu(1,x_j) - \tilde\mu(1,x_j)} \color{gray} \right\} \color{black} - \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}}\color{gray}\left\{ \color{black} \textcolor[RGB]{248,118,109}{\mu(0,x_j)} - \textcolor[RGB]{248,118,109}{\tilde\mu(0,x_j)} \color{gray} \right\} \color{black}. \]
The Matched Term \[ \begin{aligned} 0 &= \sum_{w \in 0,1} \sum_{j:w_j=0} \color{gray}\left\{ \color{black} \mu(w, x_j) - \tilde\mu(w,x_j) \color{gray} \right\} \color{black} \textcolor[RGB]{248,118,109}{1_{=0}(w)} \\ &\overset{\texttip{\small{\unicode{x2753}}}{This is a version of the 'indicator trick'. We drop terms where $w=1$ from our sum because $1_{=0}(1)=0$.}}{=} \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \color{gray}\left\{ \color{black} \textcolor[RGB]{248,118,109}{\mu(0, x_j)} - \textcolor[RGB]{248,118,109}{\tilde\mu(0,x_j)} \color{gray} \right\} \color{black} \end{aligned} \]
The Mismatched Term \[ \begin{aligned} 0 &= \sum_{w \in 0,1} \sum_{j:w_j=0} \color{gray}\left\{ \color{black} \mu(w, x_j) - \tilde\mu(w,x_j) \color{gray} \right\} \color{black} 1_{=1}(w) \\ &\overset{\texttip{\small{\unicode{x2753}}}{Via the indicator trick, as in the matched term.}}{=} \textcolor[RGB]{248,118,109}{\frac{1}{m_0} \sum_{j:w_j=0}} \color{gray}\left\{ \color{black} \textcolor[RGB]{0,191,196}{\mu(1, x_j)} - \textcolor[RGB]{0,191,196}{\tilde\mu(1,x_j)} \color{gray} \right\} \color{black} \end{aligned} \]
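The sample version of this orthogonality is worth seeing numerically. The weighted least squares fit from lm satisfies the analogous normal equations with \(Y_i\) in place of \(\mu\): the weighted residuals are orthogonal to every function in the model, which for y ~ 1 + w means the functions \(m(w,x)=1\) and \(m(w,x)=w\). This check uses sam$weights and muhat from the code above; the name res is ours, and both sums should be zero up to rounding.
res = sam$y - muhat(sam$w, sam$x)    # residuals of the weighted fit above
sum(sam$weights * res)               # orthogonal to m(w,x) = 1
sum(sam$weights * res * sam$w)       # orthogonal to m(w,x) = w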
The raw difference is bigger. In comparison with the adjusted one, the average of \(\textcolor[RGB]{0,191,196}{\mu(1,x)}\) is taken over a covariate distribution that’s shifted to the right, where it’s bigger.
They’re the same. When we use the horizontal lines model, \(\tilde\mu(1,x)\) and \(\tilde\mu(0,x)\) are constant and equal to the within-group means. So in the first term in \(\tilde\Delta_{0}\), we’re taking the average of a constant—and that constant is the first term of \(\Delta_\text{raw}\). And the same happens for the second term.
It fits the green dots on the right well because that’s where most of the green dots are. But we’re averaging it over the distribution of the red dots, which are on the left.
\[ \textcolor[RGB]{192,192,192}{ \frac{m_0}{m_1} \times \frac{p_{x \mid 0}}{p_{x \mid 1}} = } \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{\textcolor[RGB]{0,191,196}{m_{1x}}} = a+bx \]
\[ \begin{aligned} 0 &= \sum_{j=1}^m \color{gray}\left\{ \color{black} \mu(w_j, x_j) - \tilde\mu(w_j,x_j) \color{gray} \right\} \color{black} m(w_j,x_j) \qqtext{ for any } &&m \in \{ a(w)+b(w)x \} \qqtext{ and therefore for } && m(w,x) = \begin{cases} 0 & \text{if } w=0 \\ a+bx = \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{\textcolor[RGB]{0,191,196}{m_{1x}}} & \text{if } w=1 \end{cases}. \end{aligned} \]
Simplifying, using the indicator trick, we get…
\[ \begin{aligned} 0 &= \sum_{j:w_j=1} \color{gray}\left\{ \color{black} \mu(1, x_j) - \tilde\mu(1,x_j) \color{gray} \right\} \color{black} \frac{\textcolor[RGB]{248,118,109}{m_{0x_j}}}{\textcolor[RGB]{0,191,196}{m_{1x_j}}} \\ &= \sum_x \textcolor[RGB]{0,191,196}{m_{1x}} \color{gray}\left\{ \color{black} \textcolor[RGB]{0,191,196}{\mu(1, x) - \tilde\mu(1,x)} \color{gray} \right\} \color{black} \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{\textcolor[RGB]{0,191,196}{m_{1x}}} \\ &= \sum_x \textcolor[RGB]{248,118,109}{m_{0x}} \color{gray}\left\{ \color{black} \textcolor[RGB]{0,191,196}{\mu(1, x) - \tilde\mu(1,x)} \color{gray} \right\} \color{black}. \end{aligned} \]
\[ \gamma_{\pm}(w,x) = \begin{cases} 1 & \text{if } w=0 \\ -(a+bx)=-\frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{\textcolor[RGB]{0,191,196}{m_{1x}}} & \text{if } w=1 \end{cases} \]
\[ \begin{aligned} \hat\Delta_0 &= \frac{1}{m_0} \sum_{j:w_j=0} \qty{ \hat\mu(1,x_j) - \hat\mu(0,x_j) } \\ & \qfor \hat\mu = \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \frac{1}{n}\sum_{i=1}^n \gamma(W_i,X_i) \color{gray}\left\{ \color{black} Y_i - m(W_i,X_i) \color{gray} \right\} \color{black}^2 \\ & \qand \gamma(w,x) = \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{m_{wx}} = \begin{cases} 1 & \text{if } w=0 \\ \frac{\textcolor[RGB]{248,118,109}{m_{0x}}}{\textcolor[RGB]{0,191,196}{m_{1x}}} & \text{if } w=1 \end{cases} \end{aligned} \]
\[ \begin{aligned} \hat\Delta_0 &= \frac{1}{N_0} \sum_{i:W_i=0} \qty{ \hat\mu(1,X_i) - \hat\mu(0,X_i) } \\ & \qfor \hat\mu = \mathop{\mathrm{argmin}}_{m \in \mathcal{M}} \frac{1}{n}\sum_{i=1}^n \gamma(W_i,X_i) \color{gray}\left\{ \color{black} Y_i - m(W_i,X_i) \color{gray} \right\} \color{black}^2 \\ & \qand \gamma(w,x) = \frac{\textcolor[RGB]{248,118,109}{N_{0x}}}{N_{wx}} = \begin{cases} 1 & \text{if } w=0 \\ \frac{\textcolor[RGB]{248,118,109}{N_{0x}}}{\textcolor[RGB]{0,191,196}{N_{1x}}} & \text{if } w=1 \end{cases} \end{aligned} \]
\[ \mathcal{M}=\{m(w,x)=a(w) \qqtext{ for functions } a\} \tag{2}\]
\[ \mathcal{M}=\{m(w,x)=a(w)+bx \qqtext{ for functions } a \qqtext{ and constants } b\} \tag{3}\]
\[ \mathcal{M}=\{m(w,x)=a(w)+b(w)x \qqtext{ for functions } a,b\} \tag{4}\]
\[ \mathcal{M}=\{m(w,x)=a(w)+b(x) \qqtext{ for functions } a,b\} \tag{5}\]
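For reference, here is a sketch of how these four models might be fit with lm on the sample data frame sam from above, assuming w is coded 0/1 and x is discrete. Only the first formula appears in the code above; the other three are our additions.
lm(y ~ 1 + w,             data=sam)   # Model 2: a(w)
lm(y ~ 1 + w + x,         data=sam)   # Model 3: a(w) + bx
lm(y ~ 1 + w + x + w:x,   data=sam)   # Model 4: a(w) + b(w)x
lm(y ~ 1 + w + factor(x), data=sam)   # Model 5: a(w) + b(x)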
Here, for the sake of clarity, we’re talking about \(\hat\Delta_0\) defined in Slide 5.1.
Here, for the sake of clarity, we’re talking about \(\hat\Delta_0\) defined in Slide 5.2.
We just did this, but go ahead and copy it out here for reference.
Hint. For \(\Delta_0\), we wanted to make the distribution of the pretend (i.e. weighted) green dots like that of the red dots. Now the situation is reversed.
Hint. Now we’re averaging over the distribution of all dots. How do we duplicate the red dots to get that distribution? What about the green dots?
You’ll prove this for homework.