Update latest version of site
docusaurus-bot committed Jan 24, 2025
Parent: c64b2d7 · Commit: dedf2f6
Showing 10 changed files with 224 additions and 82 deletions.
35 changes: 34 additions & 1 deletion v/latest/api/_modules/botorch/models/gpytorch.html
@@ -770,6 +770,7 @@ Source code for botorch.models.gpytorch
                 interleaved=False,
             )
         else:
+            mvns = self._broadcast_mvns(mvns=mvns)
             mvn = (
                 mvns[0]
                 if len(mvns) == 1
@@ -791,7 +792,39 @@ Source code for botorch.models.gpytorch
     def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
         raise NotImplementedError()
+
+    def _broadcast_mvns(self, mvns: list[MultivariateNormal]) -> MultivariateNormal:
+        """Broadcasts the batch shapes of the given MultivariateNormals.
+
+        The MVNs will have a batch shape of `input_batch_shape x model_batch_shape`.
+        If the model batch shapes are broadcastable, we will broadcast the mvns to
+        a batch shape of `input_batch_shape x self.batch_shape`.
+
+        Args:
+            mvns: A list of MultivariateNormals.
+
+        Returns:
+            A list of MultivariateNormals with broadcasted batch shapes.
+        """
+        mvn_batch_shapes = {mvn.batch_shape for mvn in mvns}
+        if len(mvn_batch_shapes) == 1:
+            # All MVNs have the same batch shape. We can return as is.
+            return mvns
+        # This call will error out if they're not broadcastable.
+        # If they're broadcastable, it'll log a warning.
+        target_model_shape = self.batch_shape
+        max_batch = max(mvn_batch_shapes, key=len)
+        max_len = len(max_batch)
+        input_batch_len = max_len - len(target_model_shape)
+        for i in range(len(mvns)):  # Loop over index since we modify contents.
+            while len(mvns[i].batch_shape) < max_len:
+                # MVN is missing batch dimensions. Unsqueeze as needed.
+                mvns[i] = mvns[i].unsqueeze(input_batch_len)
+            if mvns[i].batch_shape != max_batch:
+                # Expand to match the batch shapes.
+                mvns[i] = mvns[i].expand(max_batch)
+        return mvns



35 changes: 34 additions & 1 deletion v/latest/api/_modules/botorch/models/gpytorch/index.html
(The diff for this file is identical to the diff for v/latest/api/_modules/botorch/models/gpytorch.html shown above.)
