Skip to content

Commit

Permalink
S3: Fix recursion for buckets in different partition (#7731)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored May 30, 2024
1 parent d3dbf15 commit 4c098ae
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
10 changes: 5 additions & 5 deletions moto/s3/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ def __init__(self, name: str, account_id: str, region_name: str):
self.default_lock_days: Optional[int] = 0
self.default_lock_years: Optional[int] = 0
self.ownership_rule: Optional[Dict[str, Any]] = None
s3_backends.bucket_accounts[name] = account_id
s3_backends.bucket_accounts[name] = (self.partition, account_id)

@property
def location(self) -> str:
Expand Down Expand Up @@ -1852,8 +1852,8 @@ def get_bucket(self, bucket_name: str) -> FakeBucket:
if bucket_name in s3_backends.bucket_accounts:
if not s3_allow_crossdomain_access():
raise AccessDeniedByLock
account_id = s3_backends.bucket_accounts[bucket_name]
return s3_backends[account_id][self.partition].get_bucket(bucket_name)
(partition, account_id) = s3_backends.bucket_accounts[bucket_name]
return s3_backends[account_id][partition].get_bucket(bucket_name)

raise MissingBucket(bucket=bucket_name)

Expand Down Expand Up @@ -3028,9 +3028,9 @@ def __init__(
):
super().__init__(backend, service_name, use_boto3_regions, additional_regions)

# Maps bucket names to account IDs. This is used to locate the exact S3Backend
# Maps bucket names to (partition, account ID) tuples. This is used to locate the exact S3Backend
# holding the bucket and to maintain the common bucket namespace.
self.bucket_accounts: Dict[str, str] = {}
self.bucket_accounts: Dict[str, Tuple[str, str]] = {}


s3_backends = S3BackendDict(
Expand Down
37 changes: 33 additions & 4 deletions tests/test_ec2/test_flow_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@


@mock_aws
def test_create_flow_logs_s3():
s3 = boto3.resource("s3", region_name="us-west-1")
client = boto3.client("ec2", region_name="us-west-1")
@pytest.mark.parametrize(
"region,partition", [("us-west-2", "aws"), ("cn-north-1", "aws-cn")]
)
def test_create_flow_logs_s3(region, partition):
s3 = boto3.resource("s3", region_name=region)
client = boto3.client("ec2", region_name=region)

vpc = client.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]

bucket_name = str(uuid4())
bucket = s3.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": "us-west-1"},
CreateBucketConfiguration={"LocationConstraint": region},
)

with pytest.raises(ClientError) as ex:
Expand Down Expand Up @@ -114,6 +117,32 @@ def test_create_multiple_flow_logs_s3():
assert flow_log_1["LogDestination"] != flow_log_2["LogDestination"]


@mock_aws
def test_create_flow_logs_s3__bucket_in_different_partition():
    """Flow logs can target an S3 bucket that lives in another partition.

    The EC2 client operates in us-west-1 (``aws`` partition) while the
    destination bucket is created in cn-north-1 (``aws-cn`` partition); the
    flow log must still be created and be describable afterwards.
    """
    s3_resource = boto3.resource("s3", region_name="cn-north-1")
    ec2 = boto3.client("ec2", region_name="us-west-1")

    vpc_id = ec2.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"]

    # A random bucket name keeps the shared bucket namespace collision-free
    # across test runs.
    bucket = s3_resource.create_bucket(
        Bucket=str(uuid4()),
        CreateBucketConfiguration={"LocationConstraint": "cn-north-1"},
    )

    flow_log_ids = ec2.create_flow_logs(
        ResourceType="VPC",
        ResourceIds=[vpc_id],
        TrafficType="ALL",
        LogDestinationType="s3",
        LogDestination="arn:aws:s3:::" + bucket.name,
    )["FlowLogIds"]
    assert len(flow_log_ids) == 1

    described = ec2.describe_flow_logs(FlowLogIds=flow_log_ids[:1])["FlowLogs"]
    assert len(described) == 1


@mock_aws
def test_create_flow_logs_cloud_watch():
client = boto3.client("ec2", region_name="us-west-1")
Expand Down

0 comments on commit 4c098ae

Please sign in to comment.