Skip to content

Commit 17da6df

Browse files
refactor(aws): enhance createAssumedRoleSession to accept account ID as parameter and improve error handling in Verify method
- Updated createAssumedRoleSession to take accountId as an argument, allowing for more flexible role assumption. - Modified Verify method to iterate over AccountIds, capturing errors for each account and returning the last encountered error if all attempts fail.
1 parent c53a8ef commit 17da6df

1 file changed

Lines changed: 17 additions & 15 deletions

File tree

pkg/providers/aws/aws.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ func New(block schema.OptionBlock) (*Provider, error) {
249249

250250
if err != nil && options.AssumeRoleName != "" && len(options.AccountIds) > 0 {
251251
// Base user doesn't have DescribeRegions permission, try with assumed role
252-
tempSession, err := createAssumedRoleSession(options, sess, config)
252+
tempSession, err := createAssumedRoleSession(options, sess, config, options.AccountIds[0])
253253
if err != nil {
254254
return nil, errors.Wrap(err, "could not create assumed role session")
255255
}
@@ -270,12 +270,9 @@ func New(block schema.OptionBlock) (*Provider, error) {
270270
return provider, nil
271271
}
272272

273-
func createAssumedRoleSession(options *ProviderOptions, sess *session.Session, config *aws.Config) (*session.Session, error) {
274-
if len(options.AccountIds) == 0 {
275-
return nil, errors.New("no account IDs provided for assume role")
276-
}
273+
func createAssumedRoleSession(options *ProviderOptions, sess *session.Session, config *aws.Config, accountId string) (*session.Session, error) {
277274
stsClient := sts.New(sess)
278-
roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", options.AccountIds[0], options.AssumeRoleName)
275+
roleArn := fmt.Sprintf("arn:aws:iam::%s:role/%s", accountId, options.AssumeRoleName)
279276

280277
roleInput := &sts.AssumeRoleInput{
281278
RoleArn: aws.String(roleArn),
@@ -524,16 +521,21 @@ func (p *Provider) Verify(ctx context.Context) error {
524521
}
525522

526523
if p.options.AssumeRoleName != "" && len(p.options.AccountIds) > 0 {
527-
tempSession, err := createAssumedRoleSession(p.options, p.session, p.session.Config)
528-
if err != nil {
529-
return err
530-
}
531-
p.initServices(tempSession)
532-
err = p.verify()
533-
if err != nil {
534-
return err
524+
var lastErr error
525+
for _, accountId := range p.options.AccountIds {
526+
tempSession, err := createAssumedRoleSession(p.options, p.session, p.session.Config, accountId)
527+
if err != nil {
528+
lastErr = err
529+
continue
530+
}
531+
p.initServices(tempSession)
532+
if err := p.verify(); err != nil {
533+
lastErr = err
534+
continue
535+
}
536+
return nil
535537
}
536-
return nil
538+
return errors.Wrap(lastErr, "failed to verify credentials across all accounts")
537539
}
538540
return err
539541
}

0 commit comments

Comments
 (0)