Skip to content

Commit

Permalink
chore(hsgp_nd): ruff lint and fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
brendancooley committed May 17, 2024
1 parent 8146ea0 commit 04f96f0
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions notebooks/source/hsgp_nd_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
" if noise is None:\n",
" return k\n",
" else:\n",
" return k + (noise ** 2 + jitter) * jnp.eye(k.shape[0])\n",
" return k + (noise**2 + jitter) * jnp.eye(k.shape[0])\n",
"\n",
"\n",
"def sample_grid_and_data(\n",
Expand Down Expand Up @@ -206,7 +206,7 @@
" alpha=post_alpha,\n",
" color=\"tab:blue\",\n",
" )\n",
" \n",
"\n",
" # plot the data points\n",
" if X is not None and y is not None:\n",
" color = (\n",
Expand All @@ -222,7 +222,7 @@
" s=point_size,\n",
" alpha=point_alpha,\n",
" )\n",
" \n",
"\n",
" # add confidence intervals at the boundaries\n",
" if xz_lines:\n",
" for line in xz_lines:\n",
Expand Down Expand Up @@ -316,7 +316,7 @@
" ax.plot(\n",
" X_grid, post_y[i, :], linewidth=1.0, alpha=post_alpha, color=\"tab:blue\"\n",
" )\n",
" \n",
"\n",
" # plot the data points\n",
" if X is not None and y is not None:\n",
" if test_ind is None:\n",
Expand All @@ -331,7 +331,7 @@
" ax.fill_between(\n",
" X_grid.squeeze(), ci[0], ci[1], color=\"tab:blue\", alpha=ci_alpha\n",
" )\n",
" \n",
"\n",
" # add the noiseless function\n",
" ax.plot(X_grid, y_grid, linewidth=1.0, alpha=1.0, color=\"tab:orange\")\n",
"\n",
Expand Down Expand Up @@ -517,9 +517,7 @@
" mcmc = fit_mcmc(seed, m.model)\n",
"else:\n",
" guide = AutoNormal(m.model, init_loc_fn=init_to_median(num_samples=25))\n",
" svi_res = fit_svi(\n",
" seed=seed, model=m.model, guide=guide\n",
" )"
" svi_res = fit_svi(seed=seed, model=m.model, guide=guide)"
]
},
{
Expand Down Expand Up @@ -999,9 +997,7 @@
"hsgp_m = HSGPModel(m=5, D=D, L=L * 2.5)\n",
"\n",
"if inference == \"mcmc\":\n",
" hsgp_mcmc = fit_mcmc(\n",
" seed, hsgp_m.model, X=X_tr, y=y_tr\n",
" )\n",
" hsgp_mcmc = fit_mcmc(seed, hsgp_m.model, X=X_tr, y=y_tr)\n",
"else:\n",
" hsgp_guide = AutoNormal(hsgp_m.model, init_loc_fn=init_to_median(num_samples=25))\n",
" hsgp_res = fit_svi(seed, hsgp_m.model, hsgp_guide, X=X_tr, y=y_tr, num_steps=10000)"
Expand Down

0 comments on commit 04f96f0

Please sign in to comment.