From cbde7f983cdd2ec5736281a79257f759349b41b4 Mon Sep 17 00:00:00 2001
From: Samuliak <samuliak77@gmail.com>
Date: Tue, 15 Oct 2024 19:48:32 +0200
Subject: [PATCH] force compile shaders if needed

---
 .../Renderer/Metal/MetalPipelineCompiler.cpp  | 55 +++++++++++++------
 .../Renderer/Metal/MetalPipelineCompiler.h    |  6 +-
 .../HW/Latte/Renderer/Metal/MetalRenderer.cpp | 28 ++++------
 3 files changed, 51 insertions(+), 38 deletions(-)

diff --git a/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.cpp b/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.cpp
index ee01f04bb..54aa83b1b 100644
--- a/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.cpp
+++ b/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.cpp
@@ -309,25 +309,22 @@ MetalPipelineCompiler::~MetalPipelineCompiler()
 
 void MetalPipelineCompiler::InitFromState(const LatteFetchShader* fetchShader, const LatteDecompilerShader* vertexShader, const LatteDecompilerShader* geometryShader, const LatteDecompilerShader* pixelShader, const MetalAttachmentsInfo& lastUsedAttachmentsInfo, const MetalAttachmentsInfo& activeAttachmentsInfo, const LatteContextRegister& lcr)
 {
-    // Shaders
-    m_vertexShader = static_cast<const RendererShaderMtl*>(vertexShader->shader);
-    if (geometryShader)
-    {
-        m_geometryShader = static_cast<const RendererShaderMtl*>(geometryShader->shader);
-    }
-    else
-    {
-        // If there is no geometry shader, it means that we are emulating rects
-        m_geometryShader = rectsEmulationGS_generate(m_mtlr, vertexShader, lcr);
-    }
-    m_pixelShader = static_cast<const RendererShaderMtl*>(pixelShader->shader);
-
     // Check if the pipeline uses a geometry shader
     const LattePrimitiveMode primitiveMode = static_cast<LattePrimitiveMode>(LatteGPUState.contextRegister[mmVGT_PRIMITIVE_TYPE]);
     bool isPrimitiveRect = (primitiveMode == Latte::LATTE_VGT_PRIMITIVE_TYPE::E_PRIMITIVE_TYPE::RECTS);
 
     m_usesGeometryShader = (geometryShader != nullptr || isPrimitiveRect);
 
+    // Shaders
+    m_vertexShaderMtl = static_cast<RendererShaderMtl*>(vertexShader->shader);
+    if (geometryShader)
+        m_geometryShaderMtl = static_cast<RendererShaderMtl*>(geometryShader->shader);
+    else if (isPrimitiveRect)
+        m_geometryShaderMtl = rectsEmulationGS_generate(m_mtlr, vertexShader, lcr);
+    else
+        m_geometryShaderMtl = nullptr;
+    m_pixelShaderMtl = static_cast<RendererShaderMtl*>(pixelShader->shader);
+
     if (m_usesGeometryShader)
         InitFromStateMesh(fetchShader, lastUsedAttachmentsInfo, activeAttachmentsInfo, lcr);
     else
@@ -336,6 +333,28 @@ void MetalPipelineCompiler::InitFromState(const LatteFetchShader* fetchShader, c
 
 MTL::RenderPipelineState* MetalPipelineCompiler::Compile(bool forceCompile, bool isRenderThread, bool showInOverlay)
 {
+    if (forceCompile)
+	{
+		// if some shader stages are not compiled yet, compile them now
+		if (m_vertexShaderMtl && !m_vertexShaderMtl->IsCompiled())
+			m_vertexShaderMtl->PreponeCompilation(isRenderThread);
+		if (m_geometryShaderMtl && !m_geometryShaderMtl->IsCompiled())
+			m_geometryShaderMtl->PreponeCompilation(isRenderThread);
+		if (m_pixelShaderMtl && !m_pixelShaderMtl->IsCompiled())
+			m_pixelShaderMtl->PreponeCompilation(isRenderThread);
+	}
+	else
+	{
+	    // fail early if some shader stages are not compiled
+		if (m_vertexShaderMtl && !m_vertexShaderMtl->IsCompiled())
+			return nullptr;
+		if (m_geometryShaderMtl && !m_geometryShaderMtl->IsCompiled())
+			return nullptr;
+		if (m_pixelShaderMtl && !m_pixelShaderMtl->IsCompiled())
+			return nullptr;
+	}
+
+	// Compile
     MTL::RenderPipelineState* pipeline = nullptr;
     NS::Error* error = nullptr;
 
@@ -345,9 +364,9 @@ MTL::RenderPipelineState* MetalPipelineCompiler::Compile(bool forceCompile, bool
         auto desc = static_cast<MTL::MeshRenderPipelineDescriptor*>(m_pipelineDescriptor);
 
         // Shaders
-        desc->setObjectFunction(m_vertexShader->GetFunction());
-        desc->setMeshFunction(m_geometryShader->GetFunction());
-        desc->setFragmentFunction(m_pixelShader->GetFunction());
+        desc->setObjectFunction(m_vertexShaderMtl->GetFunction());
+        desc->setMeshFunction(m_geometryShaderMtl->GetFunction());
+        desc->setFragmentFunction(m_pixelShaderMtl->GetFunction());
 
         NS::Error* error = nullptr;
 #ifdef CEMU_DEBUG_ASSERT
@@ -360,8 +379,8 @@ MTL::RenderPipelineState* MetalPipelineCompiler::Compile(bool forceCompile, bool
         auto desc = static_cast<MTL::RenderPipelineDescriptor*>(m_pipelineDescriptor);
 
         // Shaders
-        desc->setVertexFunction(m_vertexShader->GetFunction());
-        desc->setFragmentFunction(m_pixelShader->GetFunction());
+        desc->setVertexFunction(m_vertexShaderMtl->GetFunction());
+        desc->setFragmentFunction(m_pixelShaderMtl->GetFunction());
 
         NS::Error* error = nullptr;
 #ifdef CEMU_DEBUG_ASSERT
diff --git a/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.h b/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.h
index 4f0febefb..f39b1fb5e 100644
--- a/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.h
+++ b/src/Cafe/HW/Latte/Renderer/Metal/MetalPipelineCompiler.h
@@ -18,9 +18,9 @@ class MetalPipelineCompiler
 private:
     class MetalRenderer* m_mtlr;
 
-    const class RendererShaderMtl* m_vertexShader;
-    const class RendererShaderMtl* m_geometryShader;
-    const class RendererShaderMtl* m_pixelShader;
+    class RendererShaderMtl* m_vertexShaderMtl;
+    class RendererShaderMtl* m_geometryShaderMtl;
+    class RendererShaderMtl* m_pixelShaderMtl;
     bool m_usesGeometryShader;
 
     /*
diff --git a/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp b/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp
index 17050326f..2b420e6e2 100644
--- a/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp
+++ b/src/Cafe/HW/Latte/Renderer/Metal/MetalRenderer.cpp
@@ -944,15 +944,9 @@ void MetalRenderer::draw_execute(uint32 baseVertex, uint32 baseInstance, uint32
 
     // Shaders
     LatteDecompilerShader* vertexShader = LatteSHRC_GetActiveVertexShader();
-    if (vertexShader && !vertexShader->shader->IsCompiled())
-        return;
     LatteDecompilerShader* geometryShader = LatteSHRC_GetActiveGeometryShader();
-    if (geometryShader && !geometryShader->shader->IsCompiled())
-        return;
     LatteDecompilerShader* pixelShader = LatteSHRC_GetActivePixelShader();
     const auto fetchShader = LatteSHRC_GetActiveFetchShader();
-    if (vertexShader && !pixelShader->shader->IsCompiled())
-        return;
 
     bool neverSkipAccurateBarrier = false;
 
@@ -1004,6 +998,17 @@ void MetalRenderer::draw_execute(uint32 baseVertex, uint32 baseInstance, uint32
 	// Render pass
 	auto renderCommandEncoder = GetRenderCommandEncoder();
 
+    // Render pipeline state
+    MTL::RenderPipelineState* renderPipelineState = m_pipelineCache->GetRenderPipelineState(fetchShader, vertexShader, geometryShader, pixelShader, m_state.m_lastUsedFBO.m_attachmentsInfo, m_state.m_activeFBO.m_attachmentsInfo, LatteGPUState.contextNew);
+    if (!renderPipelineState)
+        return;
+
+    if (renderPipelineState != encoderState.m_renderPipelineState)
+   	{
+        renderCommandEncoder->setRenderPipelineState(renderPipelineState);
+  		encoderState.m_renderPipelineState = renderPipelineState;
+   	}
+
 	// Depth stencil state
 
 	// Disable depth write when there is no depth attachment
@@ -1222,17 +1227,6 @@ void MetalRenderer::draw_execute(uint32 baseVertex, uint32 baseInstance, uint32
     //    renderCommandEncoder->memoryBarrier(barrierBuffers.data(), barrierBuffers.size(), MTL::RenderStageVertex, MTL::RenderStageVertex);
     //}
 
-	// Render pipeline state
-	MTL::RenderPipelineState* renderPipelineState = m_pipelineCache->GetRenderPipelineState(fetchShader, vertexShader, geometryShader, pixelShader, m_state.m_lastUsedFBO.m_attachmentsInfo, m_state.m_activeFBO.m_attachmentsInfo, LatteGPUState.contextNew);
-    if (!renderPipelineState)
-        return;
-
-	if (renderPipelineState != encoderState.m_renderPipelineState)
-   	{
-        renderCommandEncoder->setRenderPipelineState(renderPipelineState);
-  		encoderState.m_renderPipelineState = renderPipelineState;
-   	}
-
 	// Prepare streamout
 	m_state.m_streamoutState.verticesPerInstance = count;
 	LatteStreamout_PrepareDrawcall(count, instanceCount);