Skip to content

Commit

Permalink
Merge pull request #62 from masa-su/add_log_prob
Browse files Browse the repository at this point in the history
Add log_prob
  • Loading branch information
masa-su authored Apr 10, 2019
2 parents 796f0ac + 44319d3 commit e27a4e1
Show file tree
Hide file tree
Showing 33 changed files with 699 additions and 1,210 deletions.
37 changes: 25 additions & 12 deletions examples/cvae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@
"Distributions (for training): \n",
" q(z|x,y), p(x|z,y) \n",
"Loss function: \n",
" mean(-E_q(z|x,y)[log p(x|z,y)] + KL[q(z|x,y)||p_prior(z)]) \n",
" mean(-(E_q(z|x,y)[log p(x|z,y)]) + KL[q(z|x,y)||p_prior(z)]) \n",
"Optimizer: \n",
" Adam (\n",
" Parameter Group 0\n",
Expand Down Expand Up @@ -285,7 +285,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 101.77it/s]"
"100%|██████████| 469/469 [00:04<00:00, 95.04it/s]"
]
},
{
Expand Down Expand Up @@ -314,7 +314,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 110.71it/s]"
"100%|██████████| 469/469 [00:04<00:00, 100.45it/s]"
]
},
{
Expand Down Expand Up @@ -343,7 +343,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 104.91it/s]"
"100%|██████████| 469/469 [00:04<00:00, 98.34it/s]"
]
},
{
Expand Down Expand Up @@ -372,7 +372,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 97.32it/s]"
"100%|██████████| 469/469 [00:04<00:00, 101.22it/s]"
]
},
{
Expand Down Expand Up @@ -401,7 +401,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 106.88it/s]"
"100%|██████████| 469/469 [00:04<00:00, 99.79it/s]"
]
},
{
Expand Down Expand Up @@ -430,7 +430,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 109.72it/s]"
"100%|██████████| 469/469 [00:04<00:00, 109.00it/s]"
]
},
{
Expand Down Expand Up @@ -459,7 +459,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 96.55it/s]"
"100%|██████████| 469/469 [00:04<00:00, 97.86it/s]"
]
},
{
Expand Down Expand Up @@ -488,7 +488,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 101.35it/s]"
"100%|██████████| 469/469 [00:04<00:00, 96.30it/s]"
]
},
{
Expand Down Expand Up @@ -517,7 +517,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 99.68it/s] "
"100%|██████████| 469/469 [00:03<00:00, 124.22it/s]"
]
},
{
Expand Down Expand Up @@ -546,14 +546,27 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 469/469 [00:04<00:00, 112.04it/s]\n"
"100%|██████████| 469/469 [00:03<00:00, 126.78it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 10 Train loss: 99.1070\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 10 Train loss: 99.1070\n",
"Test loss: 99.9338\n"
]
}
Expand Down
30 changes: 16 additions & 14 deletions examples/distributions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fac2fc9b8b0>"
"<torch._C.Generator at 0x7fdf8eaad8d0>"
]
},
"execution_count": 1,
Expand Down Expand Up @@ -315,12 +315,14 @@
{
"data": {
"text/plain": [
"{'x': tensor([[-1.1235, -1.1559, 0.4218, 0.8778, -0.1497, 0.2739, 1.1814, -0.7278,\n",
" 0.2572, 0.1075, -0.7142, -0.7021, 0.6641, -1.1700, -1.8278, -0.9027,\n",
" 0.6691, 0.2645, 0.2566, -0.1142],\n",
" [-0.2431, -0.5863, -0.0452, 2.1263, 0.9091, 0.5982, -0.9394, 0.3520,\n",
" -0.7051, 1.8862, 0.4602, -0.2422, -0.6304, 0.8388, 0.8246, 1.1748,\n",
" 0.3473, -0.8007, 0.2327, 0.3098]])}"
"{'x': tensor([[-1.1551e+00, -1.2686e+00, 4.2959e-01, 8.6341e-01, -1.2102e-02,\n",
" 3.1782e-01, 1.2648e+00, -8.4481e-01, 4.3645e-01, 1.8594e-01,\n",
" -6.2100e-01, -5.1982e-01, 6.1928e-01, -1.1163e+00, -1.8660e+00,\n",
" -9.2985e-01, 5.8562e-01, 2.9533e-01, 2.9360e-01, 1.1179e-03],\n",
" [-2.5069e-01, -6.9296e-01, 1.7690e-02, 2.1784e+00, 8.5724e-01,\n",
" 5.4175e-01, -8.0418e-01, 3.2449e-01, -7.0768e-01, 1.9246e+00,\n",
" 6.1470e-01, -2.2973e-01, -6.1405e-01, 7.4305e-01, 6.9249e-01,\n",
" 1.1372e+00, 3.4499e-01, -6.6509e-01, 1.5735e-01, 2.6749e-01]])}"
]
},
"execution_count": 12,
Expand All @@ -347,26 +349,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([-26.2268, -23.8622], grad_fn=<SumBackward2>)\n",
"tensor([-18.9722, -19.5073], grad_fn=<SumBackward2>)\n",
"tensor([-45.1698, -41.0407], grad_fn=<AddBackward0>)\n",
"tensor([-26.2297, -23.6749], grad_fn=<SumBackward2>)\n",
"tensor([-19.4051, -19.5314], grad_fn=<SumBackward2>)\n",
"tensor([-45.6209, -40.9298], grad_fn=<AddBackward0>)\n",
"tensor([-155.4684, -163.4326, -150.2627, -150.2103, -159.1462, -163.7559,\n",
" -168.1021, -162.1275, -160.1595, -142.4833], grad_fn=<AddBackward0>)\n"
]
}
],
"source": [
"outputs = p1.sample({\"y\":y, \"a\":a})\n",
"print(p1.log_likelihood(outputs))\n",
"print(p1.log_prob().eval(outputs))\n",
"\n",
"outputs = p2.sample({\"x\":x, \"y\":y})\n",
"print(p2.log_likelihood(outputs))\n",
"print(p2.log_prob().eval(outputs))\n",
"\n",
"outputs = p3.sample({\"y\":y, \"a\":a})\n",
"print(p3.log_likelihood(outputs))\n",
"print(p3.log_prob().eval(outputs))\n",
"\n",
"outputs = p_all.sample(batch_size=10)\n",
"print(p_all.log_likelihood(outputs))"
"print(p_all.log_prob().eval(outputs))"
]
},
{
Expand Down
Loading

0 comments on commit e27a4e1

Please sign in to comment.