Skip to content

Commit

Permalink
[Backport to main]fix security IT failure caused by weak password (#951
Browse files Browse the repository at this point in the history
…) (#1257)

* fix security IT failure caused by weak password (#951)

Signed-off-by: Yaliang Wu <[email protected]>

* Fix pre-trained model metadata parse exception

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: zane-neo <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
zane-neo and ylwu-amzn authored Aug 28, 2023
1 parent f93e789 commit a47db3a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,10 @@ public boolean isModelAllowed(MLRegisterModelInput registerModelInput, List mode
String version = registerModelInput.getVersion();
MLModelFormat modelFormat = registerModelInput.getModelFormat();
for (Object meta: modelMetaList) {
Map<String, Object> metaMap = (Map<String, Object>) meta;
String name = (String) metaMap.get("name");
Map<String, Object> versions = (Map<String, Object>) metaMap.get("versions");
Object versionObj = versions.get(version);
if (versionObj == null) return false;
Map<String, Object> versionMap = (Map<String, Object>) versionObj;
Object formatObj = versionMap.get("format");
if (formatObj == null) return false;
List<String> formats = (List<String>) formatObj;
if (name.equals(modelName) && versions.containsKey(version.toLowerCase(Locale.ROOT)) && formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) {
String name = (String) ((Map<String, Object>)meta).get("name");
List<String> versions = (List) ((Map<String, Object>)meta).get("version");
List<String> formats = (List) ((Map<String, Object>)meta).get("format");
if (name.equals(modelName) && versions.contains(version.toLowerCase(Locale.ROOT)) && formats.contains(modelFormat.toString().toLowerCase(Locale.ROOT))) {
return true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public void testDownloadPrebuiltModelMetaList() throws PrivilegedActionException
.modelNodeIds(new String[]{"node_id1"})
.build();
List modelMetaList = modelHelper.downloadPrebuiltModelMetaList(taskId, registerModelInput);
assertEquals("huggingface/sentence-transformers/all-MiniLM-L12-v2", ((Map<String, String>)modelMetaList.get(0)).get("name"));
assertEquals("huggingface/sentence-transformers/all-distilroberta-v1", ((Map<String, String>)modelMetaList.get(0)).get("name"));
}

@Test
Expand Down
50 changes: 19 additions & 31 deletions plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class MLModelGroupRestIT extends MLCommonsRestTestCase {
public ExpectedException exceptionRule = ExpectedException.none();

private String modelGroupId;
private String password = "IntegTest@MLModelGroupRestIT123";

@Before
public void setup() throws IOException {
Expand All @@ -77,56 +78,43 @@ public void setup() throws IOException {
}
createSearchRole(indexSearchAccessRole, "*");

createUser(mlNoAccessUser, mlNoAccessUser, ImmutableList.of(opensearchBackendRole));
mlNoAccessClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlNoAccessUser,
mlNoAccessUser
).setSocketTimeout(60000).build();
createUser(mlNoAccessUser, password, ImmutableList.of(opensearchBackendRole));
mlNoAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlNoAccessUser, password)
.setSocketTimeout(60000)
.build();

createUser(mlReadOnlyUser, mlReadOnlyUser, ImmutableList.of(opensearchBackendRole));
mlReadOnlyClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlReadOnlyUser,
mlReadOnlyUser
).setSocketTimeout(60000).build();
createUser(mlReadOnlyUser, password, ImmutableList.of(opensearchBackendRole));
mlReadOnlyClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlReadOnlyUser, password)
.setSocketTimeout(60000)
.build();

createUser(mlFullAccessNoIndexAccessUser, mlFullAccessNoIndexAccessUser, ImmutableList.of(opensearchBackendRole));
createUser(mlFullAccessNoIndexAccessUser, password, ImmutableList.of(opensearchBackendRole));
mlFullAccessNoIndexAccessClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlFullAccessNoIndexAccessUser,
mlFullAccessNoIndexAccessUser
password
).setSocketTimeout(60000).build();

createUser(mlFullAccessUser, mlFullAccessUser, ImmutableList.of(opensearchBackendRole));
mlFullAccessClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlFullAccessUser,
mlFullAccessUser
).setSocketTimeout(60000).build();
createUser(mlFullAccessUser, password, ImmutableList.of(opensearchBackendRole));
mlFullAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlFullAccessUser, password)
.setSocketTimeout(60000)
.build();

createUser(mlNonAdminFullAccessWithoutBackendRoleUser, mlNonAdminFullAccessWithoutBackendRoleUser, ImmutableList.of());
createUser(mlNonAdminFullAccessWithoutBackendRoleUser, password, ImmutableList.of());
mlNonAdminFullAccessWithoutBackendRoleClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlNonAdminFullAccessWithoutBackendRoleUser,
mlNonAdminFullAccessWithoutBackendRoleUser
password
).setSocketTimeout(60000).build();

createUser(
mlNonOwnerFullAccessWithBackendRoleUser,
mlNonOwnerFullAccessWithBackendRoleUser,
ImmutableList.of(opensearchBackendRole)
);
createUser(mlNonOwnerFullAccessWithBackendRoleUser, password, ImmutableList.of(opensearchBackendRole));
mlNonOwnerFullAccessWithBackendRoleClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlNonOwnerFullAccessWithBackendRoleUser,
mlNonOwnerFullAccessWithBackendRoleUser
password
).setSocketTimeout(60000).build();

createRoleMapping("ml_read_access", ImmutableList.of(mlReadOnlyUser));
Expand Down
38 changes: 15 additions & 23 deletions plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
public ExpectedException exceptionRule = ExpectedException.none();

private String modelGroupId;
private String password = "IntegTest@SecureMLRestIT123";

@Before
public void setup() throws IOException, ParseException {
Expand All @@ -77,37 +78,28 @@ public void setup() throws IOException, ParseException {
}
createSearchRole(indexSearchAccessRole, "*");

createUser(mlNoAccessUser, mlNoAccessUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
mlNoAccessClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlNoAccessUser,
mlNoAccessUser
).setSocketTimeout(60000).build();
createUser(mlNoAccessUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
mlNoAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlNoAccessUser, password)
.setSocketTimeout(60000)
.build();

createUser(mlReadOnlyUser, mlReadOnlyUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
mlReadOnlyClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlReadOnlyUser,
mlReadOnlyUser
).setSocketTimeout(60000).build();
createUser(mlReadOnlyUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
mlReadOnlyClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlReadOnlyUser, password)
.setSocketTimeout(60000)
.build();

createUser(mlFullAccessNoIndexAccessUser, mlFullAccessNoIndexAccessUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
createUser(mlFullAccessNoIndexAccessUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
mlFullAccessNoIndexAccessClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlFullAccessNoIndexAccessUser,
mlFullAccessNoIndexAccessUser
password
).setSocketTimeout(60000).build();

createUser(mlFullAccessUser, mlFullAccessUser, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
mlFullAccessClient = new SecureRestClientBuilder(
getClusterHosts().toArray(new HttpHost[0]),
isHttps(),
mlFullAccessUser,
mlFullAccessUser
).setSocketTimeout(60000).build();
createUser(mlFullAccessUser, password, new ArrayList<>(Arrays.asList(opensearchBackendRole)));
mlFullAccessClient = new SecureRestClientBuilder(getClusterHosts().toArray(new HttpHost[0]), isHttps(), mlFullAccessUser, password)
.setSocketTimeout(60000)
.build();

createRoleMapping("ml_read_access", new ArrayList<>(Arrays.asList(mlReadOnlyUser)));
createRoleMapping("ml_full_access", new ArrayList<>(Arrays.asList(mlFullAccessNoIndexAccessUser, mlFullAccessUser)));
Expand Down

0 comments on commit a47db3a

Please sign in to comment.