Skip to content

Commit

Permalink
[CALCITE-6388] PsTableFunction throws NumberFormatException when the …
Browse files Browse the repository at this point in the history
…'user' column has spaces
  • Loading branch information
asolimando committed Jul 11, 2024
1 parent 73846cc commit 8581dba
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 85 deletions.
243 changes: 158 additions & 85 deletions plus/src/main/java/org/apache/calcite/adapter/os/PsTableFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Util;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
* Table function that executes the OS "ps" command
Expand All @@ -43,114 +46,184 @@ public class PsTableFunction {
Pattern.compile("([0-9]+):([0-9]+):([0-9]+)");
private static final Pattern HOUR_MINUTE_SECOND_PATTERN =
Pattern.compile("([0-9]+):([0-9]+)\\.([0-9]+)");
private static final Pattern NUMERIC_PATTERN = Pattern.compile("(\\d+)");

// it acts as a partial mapping, missing entries are the identity (e.g., "user" -> "user")
private static final ImmutableMap<String, String> UNIX_TO_MAC_PS_FIELDS =
ImmutableMap.<String, String>builder()
.put("pgrp", "pgid")
.put("start_time", "lstart")
.put("euid", "uid")
.build();

private static final List<String> PS_FIELD_NAMES =
ImmutableList.of("user",
"pid",
"ppid",
"pgrp",
"tpgid",
"stat",
"pcpu",
"pmem",
"vsz",
"rss",
"tty",
"start_time",
"time",
"euid",
"ruid",
"sess",
"comm");

private PsTableFunction() {
throw new AssertionError("Utility class should not be instantiated");
}

/**
* Class for parsing, line by line, the output of the ps command for a
* predefined list of parameters.
*/
@VisibleForTesting
protected static class LineParser implements Function1<String, Object[]> {

@Override public Object[] apply(String line) {
final String[] tokens = line.trim().split(" +");
final Object[] values = new Object[PS_FIELD_NAMES.size()];

if (tokens.length < PS_FIELD_NAMES.size()) {
throw new IllegalArgumentException(
"Expected at least " + PS_FIELD_NAMES.size() + ", got " + tokens.length);
}

int fieldIdx = 0;
int processedTokens = 0;
// more tokens than fields, either "user" or "comm" (or both) contain whitespaces, we assume
// usernames don't have numeric parts separated by whitespaces (e.g., "root 123"), therefore
// we stop whenever we find a numeric token assuming it's the "pid" and "user" is over
if (tokens.length > PS_FIELD_NAMES.size()) {
StringBuilder sb = new StringBuilder();
for (String field : tokens) {
if (NUMERIC_PATTERN.matcher(field).matches()) {
break;
}
processedTokens++;
sb.append(field).append(" ");
}
values[fieldIdx] =
field(PS_FIELD_NAMES.get(fieldIdx), sb.deleteCharAt(sb.length() - 1).toString());
fieldIdx++;
}

for (; fieldIdx < values.length - 1; fieldIdx++) {
try {
values[fieldIdx] = field(PS_FIELD_NAMES.get(fieldIdx), tokens[processedTokens++]);
} catch (RuntimeException e) {
throw new RuntimeException("while parsing value ["
+ tokens[fieldIdx] + "] of field [" + PS_FIELD_NAMES.get(fieldIdx)
+ "] in line [" + line + "]");
}
}

// spaces also in the "comm" part
if (processedTokens < tokens.length - 1) {
StringBuilder sb = new StringBuilder();
while (processedTokens < tokens.length) {
sb.append(tokens[processedTokens++]).append(" ");
}
values[fieldIdx] =
field(PS_FIELD_NAMES.get(fieldIdx), sb.deleteCharAt(sb.length() - 1).toString());
} else {
values[fieldIdx] = field(PS_FIELD_NAMES.get(fieldIdx), tokens[processedTokens]);
}
return values;
}

private Object field(String field, String value) {
switch (field) {
case "pid":
case "ppid":
case "pgrp": // linux only; macOS equivalent is "pgid"
case "pgid": // see "pgrp"
case "tpgid":
return Integer.valueOf(value);
case "pcpu":
case "pmem":
return (int) (Float.parseFloat(value) * 10f);
case "time":
final Matcher m1 =
MINUTE_SECOND_MILLIS_PATTERN.matcher(value);
if (m1.matches()) {
final long h = Long.parseLong(m1.group(1));
final long m = Long.parseLong(m1.group(2));
final long s = Long.parseLong(m1.group(3));
return h * 3600000L + m * 60000L + s * 1000L;
}
final Matcher m2 =
HOUR_MINUTE_SECOND_PATTERN.matcher(value);
if (m2.matches()) {
final long m = Long.parseLong(m2.group(1));
final long s = Long.parseLong(m2.group(2));
StringBuilder g3 = new StringBuilder(m2.group(3));
while (g3.length() < 3) {
g3.append("0");
}
final long millis = Long.parseLong(g3.toString());
return m * 60000L + s * 1000L + millis;
}
return 0L;
case "start_time": // linux only; macOS version is "lstart"
case "lstart": // see "start_time"
case "euid": // linux only; macOS equivalent is "uid"
case "uid": // see "euid"
default:
return value;
}
}
}

public static ScannableTable eval(boolean b) {
return new AbstractBaseScannableTable() {
@Override public Enumerable<@Nullable Object[]> scan(DataContext root) {
final RelDataType rowType = getRowType(root.getTypeFactory());
final List<String> fieldNames =
ImmutableList.copyOf(rowType.getFieldNames());
final List<String> fieldNames = ImmutableList.copyOf(rowType.getFieldNames());
final String[] args;
final String osName = System.getProperty("os.name");
final String osVersion = System.getProperty("os.version");
Util.discard(osVersion);
switch (osName) {
case "Mac OS X": // tested on version 10.12.5
args = new String[] {
"ps", "ax", "-o", "ppid=,pid=,pgid=,tpgid=,stat=,"
+ "user=,pcpu=,pmem=,vsz=,rss=,tty=,start=,time=,uid=,ruid=,"
+ "sess=,comm="};
"ps", "ax", "-o",
fieldNames.stream()
.map(s -> UNIX_TO_MAC_PS_FIELDS.getOrDefault(s, s) + "=")
.collect(Collectors.joining(","))};
break;
default:
args = new String[] {
"ps", "--no-headers", "axo", "ppid,pid,pgrp,"
+ "tpgid,stat,user,pcpu,pmem,vsz,rss,tty,start_time,time,euid,"
+ "ruid,sess,comm"};
"ps", "--no-headers", "axo", String.join(",", fieldNames)};
}
return Processes.processLines(args)
.select(
new Function1<String, Object[]>() {
@Override public Object[] apply(String line) {
final String[] fields = line.trim().split(" +");
final Object[] values = new Object[fieldNames.size()];
for (int i = 0; i < values.length; i++) {
try {
values[i] = field(fieldNames.get(i), fields[i]);
} catch (RuntimeException e) {
throw new RuntimeException("while parsing value ["
+ fields[i] + "] of field [" + fieldNames.get(i)
+ "] in line [" + line + "]");
}
}
return values;
}

private Object field(String field, String value) {
switch (field) {
case "pid":
case "ppid":
case "pgrp": // linux only; macOS equivalent is "pgid"
case "pgid": // see "pgrp"
case "tpgid":
return Integer.valueOf(value);
case "pcpu":
case "pmem":
return (int) (Float.valueOf(value) * 10f);
case "time":
final Matcher m1 =
MINUTE_SECOND_MILLIS_PATTERN.matcher(value);
if (m1.matches()) {
final long h = Long.parseLong(m1.group(1));
final long m = Long.parseLong(m1.group(2));
final long s = Long.parseLong(m1.group(3));
return h * 3600000L + m * 60000L + s * 1000L;
}
final Matcher m2 =
HOUR_MINUTE_SECOND_PATTERN.matcher(value);
if (m2.matches()) {
final long m = Long.parseLong(m2.group(1));
final long s = Long.parseLong(m2.group(2));
String g3 = m2.group(3);
while (g3.length() < 3) {
g3 = g3 + "0";
}
final long millis = Long.parseLong(g3);
return m * 60000L + s * 1000L + millis;
}
return 0L;
case "start_time": // linux only; macOS version is "lstart"
case "lstart": // see "start_time"
case "euid": // linux only; macOS equivalent is "uid"
case "uid": // see "euid"
default:
return value;
}
}
});
return Processes.processLines(args).select(new LineParser());
}

@Override public RelDataType getRowType(RelDataTypeFactory typeFactory) {
return typeFactory.builder()
.add("pid", SqlTypeName.INTEGER)
.add("ppid", SqlTypeName.INTEGER)
.add("pgrp", SqlTypeName.INTEGER)
.add("tpgid", SqlTypeName.INTEGER)
.add("stat", SqlTypeName.VARCHAR)
.add("user", SqlTypeName.VARCHAR)
.add("pcpu", SqlTypeName.DECIMAL, 3, 1)
.add("pmem", SqlTypeName.DECIMAL, 3, 1)
.add("vsz", SqlTypeName.INTEGER)
.add("rss", SqlTypeName.INTEGER)
.add("tty", SqlTypeName.VARCHAR)
.add("start_time", SqlTypeName.VARCHAR)
.add("time", TimeUnit.HOUR, -1, TimeUnit.SECOND, 0)
.add("euid", SqlTypeName.VARCHAR)
.add("ruid", SqlTypeName.VARCHAR)
.add("sess", SqlTypeName.VARCHAR)
.add("command", SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(0), SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(1), SqlTypeName.INTEGER)
.add(PS_FIELD_NAMES.get(2), SqlTypeName.INTEGER)
.add(PS_FIELD_NAMES.get(3), SqlTypeName.INTEGER)
.add(PS_FIELD_NAMES.get(4), SqlTypeName.INTEGER)
.add(PS_FIELD_NAMES.get(5), SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(6), SqlTypeName.DECIMAL, 3, 1)
.add(PS_FIELD_NAMES.get(7), SqlTypeName.DECIMAL, 3, 1)
.add(PS_FIELD_NAMES.get(8), SqlTypeName.INTEGER)
.add(PS_FIELD_NAMES.get(9), SqlTypeName.INTEGER)
.add(PS_FIELD_NAMES.get(10), SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(11), SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(12), TimeUnit.HOUR, -1, TimeUnit.SECOND, 0)
.add(PS_FIELD_NAMES.get(13), SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(14), SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(15), SqlTypeName.VARCHAR)
.add(PS_FIELD_NAMES.get(16), SqlTypeName.VARCHAR)
.build();
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.os;

import com.google.common.collect.ImmutableList;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

/**
* Unit tests for the ps (process status) table function.
*/
class PsTableFunctionTest {

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6388">[CALCITE-6388]
* PsTableFunction throws NumberFormatException when the 'user' column has spaces</a>.
*/
@Test void testPsInfoParsing() {
final List<String> input = new ArrayList<>();
input.add("startup user 56399 1 56399 0 S 0.0 0.0 410348128 6672 ??"
+ " 3:25PM 0:00.22 501 501 0 /usr/lib exec/trustd");
input.add("root 1 107 107 0 Ss 0.0 0.0 410142784 4016 ??"
+ " 11Apr24 0:52.32 0 0 0 "
+ "/System/Library/PrivateFrameworks/Uninstall.framework/Resources/uninstalld");
input.add("user.name 1 1661 1661 0 S 0.7 0.2 412094800 75232 ?? "
+ "11Apr24 325:33.63 775020228 775020228 0 "
+ "/System/Library/CoreServices/ControlCenter app/Contents/MacOS/ControlCenter");

final List<List<Object>> output =
ImmutableList.of(
Arrays.asList("startup user", 56399, 1, 56399, 0, "S", 0, 0, "410348128", "6672", "??",
"3:25PM", 220L, "501", "501", "0", "/usr/lib exec/trustd"),
Arrays.asList("root", 1, 107, 107, 0, "Ss", 0, 0, "410142784", "4016", "??",
"11Apr24", 52320L, "0", "0", "0",
"/System/Library/PrivateFrameworks/Uninstall.framework/Resources/uninstalld"),
Arrays.asList("user.name", 1, 1661, 1661, 0, "S", 7, 2, "412094800", "75232", "??",
"11Apr24", 19533630L, "775020228", "775020228", "0",
"/System/Library/CoreServices/ControlCenter app/Contents/MacOS/ControlCenter"));

final Map<String, List<Object>> testValues = new HashMap<>();
for (int i = 0; i < input.size(); i++) {
testValues.put(input.get(i), output.get(i));
}

final PsTableFunction.LineParser psLineParser = new PsTableFunction.LineParser();
for (Map.Entry<String, List<Object>> e : testValues.entrySet()) {
assertThat(psLineParser.apply(e.getKey()), is(e.getValue().toArray()));
}
}
}

0 comments on commit 8581dba

Please sign in to comment.