Skip to content

Commit

Permalink
[Fix UT] fused_weight_only_linear_pass unittest modify (#59651)
Browse files Browse the repository at this point in the history
* unittest fix

* code style

* code style
  • Loading branch information
bukejiyu authored Dec 4, 2023
1 parent 36e262a commit 57a326d
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions test/ir/pir/fused_pass/test_fused_weight_only_linear_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@ def get_cuda_version():
"weight_only_linear requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class TestFusedWeightOnlyLinearPass_Fp32(PassTest):
def is_program_valid(self, program):
return True

def build_ir_progam(self):
pir_program = None
with paddle.pir_utils.IrGuard():
self.pir_program = paddle.static.Program()
with paddle.pir.core.program_guard(self.pir_program):
pir_program = paddle.static.Program()
with paddle.pir.core.program_guard(pir_program):
x = paddle.static.data(
name='x', shape=[3, 64, 64], dtype=self.dtype
)
Expand All @@ -75,11 +79,14 @@ def build_ir_progam(self):
"pd_op.matmul": 0,
"pd_op.add": 0,
}
return pir_program

def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float32'
self.build_ir_progam()

def sample_program(self):
yield self.build_ir_progam(), False

def test_check_output(self):
self.check_pass_correct()
Expand All @@ -89,7 +96,6 @@ class TestFusedWeightOnlyLinearPass_Fp16(TestFusedWeightOnlyLinearPass_Fp32):
def setUp(self):
self.place_runtime = "gpu"
self.dtype = 'float16'
self.build_ir_progam()


if __name__ == "__main__":
Expand Down

0 comments on commit 57a326d

Please sign in to comment.