diff --git a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java index 32b04db4..b39149c7 100644 --- a/java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java +++ b/java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java @@ -270,12 +270,17 @@ private static Set getAudienceSet(String audience, String[] extraAudienc } private boolean isTokenPastHalfLifetime(JwtSvid jwtSvid) { - Instant now = clock.instant(); - Date halfLife = new Date(jwtSvid.getExpiry().getTime() - (jwtSvid.getExpiry().getTime() - jwtSvid.getIssuedAt().getTime()) / 2); - Instant halfLifeInstant = Instant.ofEpochMilli(halfLife.getTime()); - return now.isAfter(halfLifeInstant); - } + Object issuedAtClaim = jwtSvid.getClaims().get("iat"); + if (!(issuedAtClaim instanceof Date)) { + return true; + } + long expiryTime = jwtSvid.getExpiry().getTime(); + long issuedAtTime = ((Date) issuedAtClaim).getTime(); + long halfLifeTime = expiryTime - (expiryTime - issuedAtTime) / 2; + + return clock.instant().toEpochMilli() > halfLifeTime; + } private void init(final Duration timeout) throws TimeoutException { CountDownLatch done = new CountDownLatch(1); diff --git a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java index 5893c767..72b1ac2c 100644 --- a/java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java +++ b/java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java @@ -1,6 +1,8 @@ package io.spiffe.workloadapi; import com.google.common.collect.Sets; +import com.nimbusds.jose.jwk.Curve; +import com.nimbusds.jwt.JWTClaimsSet; import io.spiffe.bundle.jwtbundle.JwtBundle; import io.spiffe.exception.BundleNotFoundException; import io.spiffe.exception.JwtSourceException; @@ -9,18 +11,21 @@ import io.spiffe.spiffeid.SpiffeId; import io.spiffe.spiffeid.TrustDomain; import io.spiffe.svid.jwtsvid.JwtSvid; +import io.spiffe.utils.TestUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; import java.io.IOException; +import java.security.KeyPair; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import java.util.ArrayList; import java.util.Collections; +import java.util.Date; import java.util.List; import java.util.Set; import java.util.concurrent.ExecutionException; @@ -175,6 +180,24 @@ void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache() { } } + @Test + void testFetchJwtSvidWithSubject_cachedJwtSvidWithoutIssuedAt_refetchesWithoutThrowingNullPointerException() throws JwtSvidException { + Set audience = Collections.singleton(TEST_AUDIENCE); + Date expiration = Date.from(clock.instant().plus(JWT_TTL)); + JWTClaimsSet claims = TestUtils.buildJWTClaimSet(audience, TEST_SUBJECT.toString(), expiration); + KeyPair keyPair = TestUtils.generateECKeyPair(Curve.P_521); + JwtSvid svidWithoutIssuedAt = JwtSvid.parseInsecure(TestUtils.generateToken(claims, keyPair, "authority1"), audience); + + jwtSource.putCachedJwtSvidsForTest(TEST_SUBJECT, audience, Collections.singletonList(svidWithoutIssuedAt)); + int initialCallCount = workloadApiClient.getFetchJwtSvidCallCount(); + + JwtSvid svid = assertDoesNotThrow(() -> jwtSource.fetchJwtSvid(TEST_SUBJECT, TEST_AUDIENCE)); + + assertNotNull(svid); + assertEquals(TEST_SUBJECT, svid.getSpiffeId()); + assertEquals(initialCallCount + 1, workloadApiClient.getFetchJwtSvidCallCount()); + } + @Test void testFetchJwtSvidWithSubject_JwtSvidExpiredInCache_MultipleThreads() { // test fetchJwtSvid with several threads trying to read and write the cache