Skip to content

Commit

Permalink
Merge pull request #31 from ai-systems/feat/from_splits
Browse files Browse the repository at this point in the history
Feat/from splits
  • Loading branch information
juliarozanova authored Aug 23, 2021
2 parents d304330 + 403bd5e commit 45b283e
Show file tree
Hide file tree
Showing 16 changed files with 473 additions and 64 deletions.
10 changes: 10 additions & 0 deletions model3_test_control.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
1
0
1
1
1
1
0
1
1
1
1 change: 1 addition & 0 deletions probe_ably/core/flows/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .probe_from_dataloaders import probe_from_dataloaders
20 changes: 20 additions & 0 deletions probe_ably/core/flows/probe_from_dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import click
import prefect
from dynaconf import settings
from loguru import logger
from prefect import Flow
from prefect.engine.flow_runner import FlowRunner
from probe_ably.core.tasks.probing import TrainProbingTask
from probe_ably.core.tasks.metric_task import ProcessMetricTask

INPUT_FILE = "./tests/sample_files/test_input/multi_task_multi_model_with_control.json"
train_probing_task = TrainProbingTask()
process_metric_task = ProcessMetricTask()

def probe_from_dataloaders(config_dict, prepared_data):
with Flow("Running Probe") as flow1:
train_results = train_probing_task(prepared_data, config_dict["probing_setup"])
processed_results = process_metric_task(
train_results, config_dict["probing_setup"]
)
FlowRunner(flow=flow1).run
3 changes: 2 additions & 1 deletion probe_ably/core/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .abstract_model import AbstractModel
from .linear import LinearModel
from .mlp import MLPModel
from .mlp import MLPModel
from .model_params import ModelParams
10 changes: 5 additions & 5 deletions probe_ably/core/models/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def get_norm(self) -> Tensor:
return penalty


# def get_rank(self):
# ext_matrix = torch.cat([self.linear.weight, self.linear.bias.unsqueeze(-1)], dim=1)
# _, svd_matrix, _ = np.linalg.svd(ext_matrix.cpu().numpy())
# rank = np.sum(svd_matrix > 1e-3)
# return rank
def get_rank(self):
ext_matrix = torch.cat([self.linear.weight, self.linear.bias.unsqueeze(-1)], dim=1)
_, svd_matrix, _ = np.linalg.svd(ext_matrix.cpu().numpy())
rank = np.sum(svd_matrix > 1e-3)
return rank
50 changes: 50 additions & 0 deletions probe_ably/core/models/model_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Dict

class ModelParams():
def __init__(self)->Dict:
self.default_params = {
"probe_ably.core.models.linear.LinearModel": {
"params": [
{
"name": "dropout",
"type": "float_range",
"options": [0.0, 0.51]
},
{
"name": "alpha",
"type": "function",
"function_location": "probe_ably.core.utils.param_functions.nuclear_norm_alpha_generation",
"options": [-10.0, 3]
}]
},
"probe_ably.core.models.mlp.MLPModel": {
"params": [
{
"name": "hidden_size",
"type": "function",
"step": 0.01,
"function_location": "probe_ably.core.utils.param_functions.hidden_size_generation",
"options": [
2,
5
]
},
{
"name": "n_layers",
"type": "int_range",
"options": [
1,
2
]
},
{
"name": "dropout",
"type": "float_range",
"options": [
0.0,
0.5
]
}
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def generate_random_labels(unique_labels, labels_size):

return random_labels

@overrides
def run(self, input_data, input_labels):

unique_labels = self.get_unique_labels(input_labels)
Expand Down
1 change: 0 additions & 1 deletion probe_ably/core/tasks/metric_task/process_metric_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class ProcessMetricTask(Task):
@overrides
def run(
self, train_results: Dict[str, Dict], probing_configuration: Dict[str, Dict]
):
Expand Down
Loading

0 comments on commit 45b283e

Please sign in to comment.