Skip to content

Commit

Permalink
Updsate the zipfile format to prevent choking on json parsing (#4115)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
kbirk and github-actions[bot] authored Jul 12, 2024
1 parent 51ba624 commit 10a9866
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
import io.swagger.v3.oas.annotations.tags.Tag;
import io.swagger.v3.oas.annotations.tags.Tags;
import jakarta.transaction.Transactional;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
Expand Down Expand Up @@ -512,31 +508,13 @@ public ResponseEntity<byte[]> exportProject(@PathVariable("id") final UUID id) {
try {
final ProjectExport export = cloneService.exportProject(id);

final byte[] exportBytes = objectMapper.writeValueAsBytes(export);

try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ZipOutputStream zipOutputStream = new ZipOutputStream(byteArrayOutputStream)) {

final ZipEntry zipEntry = new ZipEntry("project.json");
zipOutputStream.putNextEntry(zipEntry);
zipOutputStream.write(exportBytes);
zipOutputStream.closeEntry();

zipOutputStream.finish();
zipOutputStream.close();

final HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.parseMediaType("application/zip"));
final String filename = "project-" + id + ".zip";
headers.setContentDispositionFormData(filename, filename);
headers.setCacheControl("must-revalidate, post-check=0, pre-check=0");

return new ResponseEntity<>(byteArrayOutputStream.toByteArray(), headers, HttpStatus.OK);
} catch (final IOException e) {
e.printStackTrace();
return new ResponseEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
final HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.parseMediaType("application/zip"));
final String filename = "project-" + id + ".zip";
headers.setContentDispositionFormData(filename, filename);
headers.setCacheControl("must-revalidate, post-check=0, pre-check=0");

return new ResponseEntity<>(export.getAsZipFile(), headers, HttpStatus.OK);
} catch (final Exception e) {
log.error("Error exporting project", e);
throw new ResponseStatusException(
Expand Down Expand Up @@ -578,29 +556,12 @@ public ResponseEntity<Project> importProject(@RequestPart("file") final Multipar
return ResponseEntity.badRequest().build();
}

final ProjectExport projectExport;
final ProjectExport projectExport = new ProjectExport();
try {
final ZipInputStream zipInputStream = new ZipInputStream(input.getInputStream());
final ZipEntry zipEntry = zipInputStream.getNextEntry();
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();

final byte[] buffer = new byte[1024];
int length;
while ((zipEntry != null) && (length = zipInputStream.read(buffer)) > 0) {
outputStream.write(buffer, 0, length);
}

final byte[] unzippedBytes = outputStream.toByteArray();

zipInputStream.closeEntry();
zipInputStream.close();
outputStream.close();

projectExport = objectMapper.readValue(unzippedBytes, ProjectExport.class);

} catch (final IOException e) {
e.printStackTrace();
return ResponseEntity.internalServerError().build();
projectExport.loadFromZipFile(input.getInputStream());
} catch (final Exception e) {
log.error("Error parsing project", e);
return ResponseEntity.badRequest().build();
}

final String userId = currentUserService.get().getId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@
public class AssetExportDeserializer extends JsonDeserializer<AssetExport> {

@Override
public AssetExport deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException {
ObjectMapper mapper = (ObjectMapper) jp.getCodec();
JsonNode node = mapper.readTree(jp);
public AssetExport deserialize(final JsonParser jp, final DeserializationContext ctxt) throws IOException {
final ObjectMapper mapper = (ObjectMapper) jp.getCodec();
final JsonNode node = mapper.readTree(jp);

final String assetTypeStr = node.get("type").asText();
final AssetType assetType = AssetType.getAssetType(assetTypeStr, mapper);

TerariumAsset asset = mapper.treeToValue(node.get("asset"), assetType.getAssetClass());
final TerariumAsset asset = mapper.treeToValue(node.get("asset"), assetType.getAssetClass());

Map<String, FileExport> files = new HashMap<>();
if (node.has("files")) {
if (node.hasNonNull("files")) {
files = mapper.convertValue(node.get("files"), new TypeReference<Map<String, FileExport>>() {});
}

AssetExport export = new AssetExport();
final AssetExport export = new AssetExport();
export.setType(assetType);
export.setAsset(asset);
export.setFiles(files);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public ContentType deserialize(final JsonParser jp, final DeserializationContext
final JsonNode node = mapper.readTree(jp);

final String mimeTypeStr = node.get("mimeType").asText();
if (!node.has("charset")) {
if (!node.hasNonNull("charset")) {
return ContentType.create(mimeTypeStr);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package software.uncharted.terarium.hmiserver.models.dataservice;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import lombok.Data;
import lombok.experimental.Accessors;
Expand All @@ -12,5 +13,6 @@ public class FileExport {
@JsonDeserialize(using = ContentTypeDeserializer.class)
ContentType contentType;

@JsonIgnore
byte[] bytes;
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,19 @@ public class DocumentAsset extends TerariumAsset {

@Override
public List<String> getFileNames() {
final List<String> res = new ArrayList<>();
if (this.fileNames != null) {
for (final String fileName : fileNames) {
if (!res.contains(fileName)) {
res.add(fileName);
}
}
if (this.fileNames == null) {
this.fileNames = new ArrayList<>();
}

// ensure these are included in filenames
if (this.assets != null) {
for (final DocumentExtraction asset : assets) {
if (!res.contains(asset.getFileName())) {
res.add(asset.getFileName());
if (!this.fileNames.contains(asset.getFileName())) {
this.fileNames.add(asset.getFileName());
}
}
}
return res;
return this.fileNames;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
package software.uncharted.terarium.hmiserver.models.dataservice.project;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import lombok.Data;
import lombok.experimental.Accessors;
import software.uncharted.terarium.hmiserver.models.TerariumAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.AssetExport;
import software.uncharted.terarium.hmiserver.models.dataservice.FileExport;
import software.uncharted.terarium.hmiserver.utils.AssetDependencyUtil;
import software.uncharted.terarium.hmiserver.utils.AssetDependencyUtil.AssetDependencyMap;

Expand All @@ -21,6 +30,88 @@ public class ProjectExport {
Project project;
List<AssetExport> assets = new ArrayList<>();

private byte[] readZipEntry(final ZipInputStream zipInputStream) throws IOException {
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
final byte[] buffer = new byte[1024];
int count;
while ((count = zipInputStream.read(buffer)) != -1) {
baos.write(buffer, 0, count);
}
return baos.toByteArray();
}

public void loadFromZipFile(final InputStream inputStream) throws IOException {

final ObjectMapper objectMapper = new ObjectMapper();

final ZipInputStream zipInputStream = new ZipInputStream(inputStream);

// get the project json
ZipEntry zipEntry = zipInputStream.getNextEntry();
if (zipEntry == null || !zipEntry.getName().equals("project.json")) {
throw new IllegalArgumentException("Invalid project export file");
}

project = objectMapper.readValue(readZipEntry(zipInputStream), Project.class);

// iterate on assets
while ((zipEntry = zipInputStream.getNextEntry()) != null) {

// read the asset json
final AssetExport asset = objectMapper.readValue(readZipEntry(zipInputStream), AssetExport.class);

// read in the file payloads
for (final Map.Entry<String, FileExport> entry : asset.getFiles().entrySet()) {
zipEntry = zipInputStream.getNextEntry();
if (zipEntry == null) {
throw new IllegalArgumentException("Invalid project export file, expected a asset file payload");
}
final FileExport file = entry.getValue();
file.setBytes(readZipEntry(zipInputStream));
}

assets.add(asset);
}
}

public byte[] getAsZipFile() throws JsonProcessingException, IOException {

final ObjectMapper objectMapper = new ObjectMapper();

final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
final ZipOutputStream zipOutputStream = new ZipOutputStream(byteArrayOutputStream);

final byte[] projectBytes = objectMapper.writeValueAsBytes(project);

final ZipEntry zipEntry = new ZipEntry("project.json");
zipOutputStream.putNextEntry(zipEntry);
zipOutputStream.write(projectBytes);
zipOutputStream.closeEntry();

for (final AssetExport asset : assets) {

final byte[] assetBytes = objectMapper.writeValueAsBytes(asset);

final ZipEntry assetEntry = new ZipEntry(asset.getAsset().getId() + ".json");
zipOutputStream.putNextEntry(assetEntry);
zipOutputStream.write(assetBytes);
zipOutputStream.closeEntry();

for (final Map.Entry<String, FileExport> file : asset.getFiles().entrySet()) {
final byte[] fileBytes = file.getValue().getBytes();

final ZipEntry fileEntry = new ZipEntry(asset.getAsset().getId() + "/" + file.getKey());
zipOutputStream.putNextEntry(fileEntry);
zipOutputStream.write(fileBytes);
zipOutputStream.closeEntry();
}
}
zipOutputStream.finish();
zipOutputStream.close();

return byteArrayOutputStream.toByteArray();
}

public ProjectExport clone() {

final ProjectExport cloned = new ProjectExport();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public interface ProjectRepository extends PSCrudRepository<Project, UUID>, JpaS
+ "where "
+ " p.id in (:ids) "
+ " and p.deletedOn is null")
List<ProjectAndAssetAggregate> findByIdsWithAssets(final List<UUID> ids);
List<ProjectAndAssetAggregate> findByIdsWithAssets(@Param("ids") final List<UUID> ids);

@Query(value = "SELECT public_asset FROM project WHERE id = :id", nativeQuery = true)
Optional<Boolean> findPublicAssetByIdNative(@Param("id") UUID id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,10 @@ public Model createAsset(final Model asset, final UUID projectId, final Schema.P
// Execute the update request
uploadEmbeddings(created.getId(), embeddings, hasWritePermission);
} catch (final Exception e) {
log.error("Failed to update embeddings for document {}", created.getId(), e);
log.error("Failed to update embeddings for model {}", created.getId(), e);
}
})
.start();
;
}

return created;
Expand Down Expand Up @@ -181,11 +180,10 @@ public Optional<Model> updateAsset(
// Execute the update request
uploadEmbeddings(updated.getId(), embeddings, hasWritePermission);
} catch (final Exception e) {
log.error("Failed to update embeddings for document {}", updated.getId(), e);
log.error("Failed to update embeddings for model {}", updated.getId(), e);
}
})
.start();
;
}

return updatedOptional;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void testItCanCreateDocument() throws Exception {
Assertions.assertEquals(before.getId(), after.getId());
Assertions.assertNotNull(after.getId());
Assertions.assertNotNull(after.getCreatedOn());
Assertions.assertEquals(after.getFileNames().size(), 2);
Assertions.assertEquals(3, after.getFileNames().size());

Assertions.assertNotNull(after.getGrounding());
Assertions.assertNotNull(after.getGrounding().getId());
Expand Down

0 comments on commit 10a9866

Please sign in to comment.