MAF

Tags: machine learning, generative modeling, probability, statistics, normalizing flows

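Illustration of the slow (sequential) forward pass of a Masked Autoregressive Flow (MAF) layer, as introduced in arXiv:1705.07057. The step shown computes $x_{d+1}$ from $z_{d+1}$ and the previously generated $x_1, \dots, x_d$ (red borders), which feed the scale function $s$ and the translation function $t$. Inspired by https://blog.evjang.com/2018/01/nf2.html.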
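
As a reading aid, the step drawn in the figure can be written out as below. This is only a sketch: it assumes the parameterization of the MAF paper, in which $s$ outputs a log-scale. The figure itself shows only that $s$ enters through $\odot$ and $t$ through $\oplus$.

```latex
% One step of the generative pass (assumption: s is a log-scale, as in arXiv:1705.07057)
\begin{align}
  x_{d+1} &= z_{d+1} \odot \exp\bigl(s(x_1, \dots, x_d)\bigr) + t(x_1, \dots, x_d) \\
  % Reverse direction: every x_i is already known, so no recursion is needed
  z_{d+1} &= \bigl(x_{d+1} - t(x_1, \dots, x_d)\bigr) \odot \exp\bigl(-s(x_1, \dots, x_d)\bigr)
\end{align}
```

Because $x_{d+1}$ depends on $x_1, \dots, x_d$, generating all $D$ components takes $D$ sequential steps. The second line, by contrast, can be evaluated for every index in parallel once $x$ is observed.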


maf.tex

\documentclass[tikz]{standalone}

\usetikzlibrary{calc,positioning}

\begin{document}
\begin{tikzpicture}[
    thick, text centered,
    box/.style={draw, thin, minimum width=1cm},
    func/.style={circle, text=white},
    input/.style={draw=red, very thick},
  ]

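  % x nodes: x_1..x_d (red border) are already generated and feed s and t;
  % x_{d+1} (faded green) is computed in this step, x_{d+2}..x_D come later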
  \node[box, input, fill=blue!20] (x1) {$x_1$};
  \node[box, input, fill=blue!20, right of=x1] (x2) {$x_2$};
  \node[right of=x2] (xdots1) {\dots};
  \node[box, input, fill=blue!20, right of=xdots1] (xd) {$x_d$};
  \node[box, fill=green!60!black, text opacity=1, opacity=0.4, right=2 of xd] (xdp1) {$x_{d+1}$};
  \node[right of=xdp1] (xdots2) {\dots};
  \node[box, fill=green!60!black, text opacity=1, opacity=0.4, right of=xdots2] (xD) {$x_D$};

  % z nodes: latent inputs; z_{d+1} (red border) is the one consumed in this step
  \node[box, fill=blue!20, below=3 of x1] (z1) {$z_1$};
  \node[box, fill=blue!20, right of=z1] (z2) {$z_2$};
  \node[right of=z2] (zdots1) {\dots};
  \node[box, fill=blue!20, right of=zdots1] (zd) {$z_d$};
  \node[box, input, fill=orange!40, right=2 of zd] (zdp1) {$z_{d+1}$};
  \node[right of=zdp1] (zdots2) {\dots};
  \node[box, fill=orange!40, right of=zdots2] (zD) {$z_D$};

  % z to x line: z_{d+1} is transformed into x_{d+1}
  \draw[->] (zdp1) -- (xdp1);

  % translate (t) and scale (s) functions; the shaded triangles indicate that
  % both are conditioned on x_1..x_d
  \node[func, font=\large, fill=teal, above right=0.1] (t) at ($(zd)!0.5!(xdp1)$) {$t$};
  \fill[teal, opacity=0.5] (x1.south west) -- (t.center) -- (xd.south east) -- cycle;

  \node[func, font=\large, fill=orange, below left=0.1] (s) at ($(zd)!0.5!(xdp1)$) {$s$};
  \fill[orange, opacity=0.5] (x1.south west) -- (s.center) -- (xd.south east) -- cycle;

  % apply s and t along the z_{d+1} -> x_{d+1} arrow: scale first (\odot), then shift (\oplus)
  \node[func, inner sep=0, fill=orange] (odot1) at ($(zdp1)!0.4!(xdp1)$) {$\odot$};
  \node[func, inner sep=0, fill=teal] (oplus1) at ($(zdp1)!0.7!(xdp1)$) {$\oplus$};
  \draw[orange, ->] (s) to[bend right=5] (odot1);
  \draw[teal, ->] (t) to[bend right=5] (oplus1);

\end{tikzpicture}
\end{document}