Skip to content

Commit

Permalink
Fix missing array signature for ActiveRecordRelation #create
Browse files Browse the repository at this point in the history
  • Loading branch information
bitwise-aiden committed Jul 16, 2024
1 parent ee9d273 commit 00d9b03
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 16 deletions.
61 changes: 45 additions & 16 deletions lib/tapioca/dsl/compilers/active_record_relations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def gather_constants
[:find_or_create_by, :find_or_create_by!, :find_or_initialize_by, :create_or_find_by, :create_or_find_by!],
T::Array[Symbol],
)
BUILDER_METHODS = T.let([:new, :build, :create, :create!], T::Array[Symbol])
BUILDER_METHODS = T.let([:new, :create, :create!, :build], T::Array[Symbol])
TO_ARRAY_METHODS = T.let([:to_ary, :to_a], T::Array[Symbol])

private
Expand Down Expand Up @@ -991,25 +991,54 @@ def create_common_methods
end

FIND_OR_CREATE_METHODS.each do |method_name|
block_type = "T.nilable(T.proc.params(object: #{constant_name}).void)"
create_common_method(
method_name,
parameters: [
create_param("attributes", type: "T.untyped"),
create_block_param("block", type: block_type),
],
return_type: constant_name,
# `T.untyped` matches `T::Array[T.untyped]` so the array signature
# must be defined first for Sorbet to pick it, if valid.
sigs = [
common_relation_methods_module.create_sig(
parameters: {
attributes: "T::Array[T.untyped]",
block: "T.nilable(T.proc.params(objects: #{constant_name}).void)",
},
return_type: "T::Array[#{constant_name}]",
),
common_relation_methods_module.create_sig(
parameters: {
attributes: "T.untyped",
block: "T.nilable(T.proc.params(object: #{constant_name}).void)",
},
return_type: constant_name,
),
]
common_relation_methods_module.create_method_with_sigs(
method_name.to_s,
sigs: sigs,
parameters: [RBI::ReqParam.new("attributes"), RBI::BlockParam.new("block")],
)
end

BUILDER_METHODS.each do |method_name|
create_common_method(
method_name,
parameters: [
create_opt_param("attributes", type: "T.untyped", default: "nil"),
create_block_param("block", type: "T.nilable(T.proc.params(object: #{constant_name}).void)"),
],
return_type: constant_name,
# `T.untyped` matches `T::Array[T.untyped]` so the array signature
# must be defined first for Sorbet to pick it, if valid.
sigs = [
common_relation_methods_module.create_sig(
parameters: {
attributes: "T::Array[T.untyped]",
block: "T.nilable(T.proc.params(objects: #{constant_name}).void)",
},
return_type: "T::Array[#{constant_name}]",
),
common_relation_methods_module.create_sig(
parameters: {
attributes: "T.untyped",
block: "T.nilable(T.proc.params(object: #{constant_name}).void)",
},
return_type: constant_name,
),
]
common_relation_methods_module.create_method_with_sigs(
method_name.to_s,
sigs: sigs,
parameters: [RBI::OptParam.new("attributes", "nil"), RBI::BlockParam.new("block")],
)
end
end
Expand Down
18 changes: 18 additions & 0 deletions spec/tapioca/dsl/compilers/active_record_relations_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def any?(&block); end
sig { params(column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) }
def average(column_name); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def build(attributes = nil, &block); end
Expand All @@ -107,15 +108,19 @@ def calculate(operation, column_name); end
sig { params(column_name: NilClass, block: T.proc.params(object: ::Post).void).returns(Integer) }
def count(column_name = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create(attributes = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create!(attributes = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create_or_find_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create_or_find_by!(attributes, &block); end
Expand Down Expand Up @@ -150,12 +155,15 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def find_or_create_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def find_or_create_by!(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def find_or_initialize_by(attributes, &block); end
Expand Down Expand Up @@ -224,6 +232,7 @@ def member?(record); end
sig { params(column_name: T.any(String, Symbol)).returns(T.untyped) }
def minimum(column_name); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def new(attributes = nil, &block); end
Expand Down Expand Up @@ -790,6 +799,7 @@ def any?(&block); end
sig { params(column_name: T.any(String, Symbol)).returns(T.any(Integer, Float, BigDecimal)) }
def average(column_name); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def build(attributes = nil, &block); end
Expand All @@ -800,15 +810,19 @@ def calculate(operation, column_name); end
sig { params(column_name: NilClass, block: T.proc.params(object: ::Post).void).returns(Integer) }
def count(column_name = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create(attributes = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create!(attributes = nil, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create_or_find_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def create_or_find_by!(attributes, &block); end
Expand Down Expand Up @@ -847,12 +861,15 @@ def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, o
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def find_or_create_by(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def find_or_create_by!(attributes, &block); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def find_or_initialize_by(attributes, &block); end
Expand Down Expand Up @@ -921,6 +938,7 @@ def member?(record); end
sig { params(column_name: T.any(String, Symbol)).returns(T.untyped) }
def minimum(column_name); end
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(objects: ::Post).void)).returns(T::Array[::Post]) }
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
def new(attributes = nil, &block); end
Expand Down

0 comments on commit 00d9b03

Please sign in to comment.