Skip to content

Commit

Permalink
Fix missing array signature for ActiveRecordRelation #create
Browse files Browse the repository at this point in the history
  • Loading branch information
bitwise-aiden committed Jul 8, 2024
1 parent ee9d273 commit cc8cf58
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 9 deletions.
55 changes: 46 additions & 9 deletions lib/tapioca/dsl/compilers/active_record_relations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def gather_constants
[:find_or_create_by, :find_or_create_by!, :find_or_initialize_by, :create_or_find_by, :create_or_find_by!],
T::Array[Symbol],
)
BUILDER_METHODS = T.let([:new, :build, :create, :create!], T::Array[Symbol])
CREATE_METHODS = T.let([:create, :create!, :build], T::Array[Symbol])
BUILDER_METHODS = T.let([:new], T::Array[Symbol])
TO_ARRAY_METHODS = T.let([:to_ary, :to_a], T::Array[Symbol])

private
Expand Down Expand Up @@ -991,14 +992,26 @@ def create_common_methods
end

FIND_OR_CREATE_METHODS.each do |method_name|
block_type = "T.nilable(T.proc.params(object: #{constant_name}).void)"
create_common_method(
method_name,
parameters: [
create_param("attributes", type: "T.untyped"),
create_block_param("block", type: block_type),
],
return_type: constant_name,
sigs = [
common_relation_methods_module.create_sig(
parameters: {
attributes: "T.untyped",
block: "T.nilable(T.proc.params(object: #{constant_name}).void)",
},
return_type: constant_name,
),
common_relation_methods_module.create_sig(
parameters: {
attributes: "T::Array[T.untyped]",
block: "T.nilable(T.proc.params(objects: #{constant_name}).void)",
},
return_type: "T::Array[#{constant_name}]",
),
]
common_relation_methods_module.create_method_with_sigs(
method_name.to_s,
sigs: sigs,
parameters: [RBI::ReqParam.new("attributes"), RBI::BlockParam.new("block")],
)
end

Expand All @@ -1012,6 +1025,30 @@ def create_common_methods
return_type: constant_name,
)
end

CREATE_METHODS.each do |method_name|
sigs = [
common_relation_methods_module.create_sig(
parameters: {
attributes: "T.untyped",
block: "T.nilable(T.proc.params(object: #{constant_name}).void)",
},
return_type: constant_name,
),
common_relation_methods_module.create_sig(
parameters: {
attributes: "T::Array[T.untyped]",
block: "T.nilable(T.proc.params(objects: #{constant_name}).void)",
},
return_type: "T::Array[#{constant_name}]",
),
]
common_relation_methods_module.create_method_with_sigs(
method_name.to_s,
sigs: sigs,
parameters: [RBI::OptParam.new("attributes", "nil"), RBI::BlockParam.new("block")],
)
end
end

sig do
Expand Down
16 changes: 16 additions & 0 deletions spec/tapioca/dsl/compilers/active_record_relations_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def any?(&block); end
def average(column_name); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def build(attributes = nil, &block); end
sig { params(operation: Symbol, column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) }
Expand All @@ -108,15 +109,19 @@ def calculate(operation, column_name); end
def count(column_name = nil, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create(attributes = nil, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create!(attributes = nil, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by(attributes, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by!(attributes, &block); end
sig { returns(T::Array[::Post]) }
Expand Down Expand Up @@ -151,12 +156,15 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by(attributes, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by!(attributes, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def find_or_initialize_by(attributes, &block); end
sig { params(signed_id: T.untyped, purpose: T.untyped).returns(T.nilable(::Post)) }
Expand Down Expand Up @@ -791,6 +799,7 @@ def any?(&block); end
def average(column_name); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def build(attributes = nil, &block); end
sig { params(operation: Symbol, column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) }
Expand All @@ -801,15 +810,19 @@ def calculate(operation, column_name); end
def count(column_name = nil, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create(attributes = nil, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create!(attributes = nil, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by(attributes, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def create_or_find_by!(attributes, &block); end
sig { returns(T::Array[::Post]) }
Expand Down Expand Up @@ -848,12 +861,15 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by(attributes, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def find_or_create_by!(attributes, &block); end
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
def find_or_initialize_by(attributes, &block); end
sig { params(signed_id: T.untyped, purpose: T.untyped).returns(T.nilable(::Post)) }
Expand Down

0 comments on commit cc8cf58

Please sign in to comment.