diff --git a/tasks/crud/CRUD_API_README.md b/tasks/crud/CRUD_API_README.md new file mode 100644 index 0000000..2c7657c --- /dev/null +++ b/tasks/crud/CRUD_API_README.md @@ -0,0 +1,144 @@ + +# 📦 CRUD API with Golang, MongoDB & Gorilla Mux + +A lightweight, performant RESTful API in Go designed to handle basic Create, Read, Update, and Delete (CRUD) operations over file metadata using MongoDB as the database and Gorilla Mux as the router. + +--- + +## 🌟 Overview + +This project implements a modular and extensible file metadata management service. It provides clean REST endpoints to perform CRUD operations on a MongoDB collection. Designed to be simple yet scalable, the API can be easily integrated into larger systems such as content managers, file explorers, or intelligent search frameworks like FileNest. + +--- + +## 🎯 Key Features + +- ⚙️ RESTful API with proper HTTP methods and status codes +- 🛡️ Input validation and error handling +- 🗃️ MongoDB integration with BSON support for efficient storage +- 🔄 Endpoints for Create, Read (all/single), Update, and Delete operations +- 📎 Built with modular code for easy extension +- 📫 Tested with Postman and Curl + +--- + +## 🏗️ Project Architecture + +``` +crud-api/ +├── controllers/ # API logic and handler functions +├── models/ # Data model (MongoDB schema) +├── router/ # Gorilla Mux routing definitions +├── .env # MongoDB credentials and configs +├── go.mod # Module dependencies +├── go.sum # Module checksums +├── main.go # Application entrypoint +└── README.md # This file +``` + +--- + +## ⚙️ Tech Stack + +| Layer | Tech | +|----------------|-------------------------------| +| Language | Golang 1.21+ | +| Router | Gorilla Mux | +| Database | MongoDB Atlas (or local) | +| Driver | Mongo Go Driver | +| Testing Tool | Postman, Curl | + +--- + +## 🚀 Quick Start + +### 📋 Prerequisites + +- Go 1.21 or higher +- MongoDB Atlas or local MongoDB instance +- Git + +### 📥 Installation + +```bash +# Download Go dependencies +go mod 
MONGODB_URI=mongodb+srv://<username>:<password>@cluster0.mongodb.net/?retryWrites=true&w=majority
+ "context" + "crud-api/models" + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "time" + + "github.com/gorilla/mux" + "github.com/joho/godotenv" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +const ( + dbName = "FileDb" + colName = "Files" +) + +var collection *mongo.Collection + +// Initializes MongoDB connection and sets the collection. +func init() { + if err := godotenv.Load(); err != nil { + log.Fatal("Error loading .env file") + } + + connectionString := os.Getenv("MONGODB_URI") + clientOptions := options.Client().ApplyURI(connectionString) + + client, err := mongo.Connect(context.TODO(), clientOptions) + if err != nil { + log.Fatal("Mongo connection error:", err) + } + + fmt.Println("MongoDB connection successful") + collection = client.Database(dbName).Collection(colName) +} + +// Inserts a new file document into MongoDB. +func createFile(file *models.FileMetadata) error { + file.ID = primitive.NewObjectID() + file.CreatedAt = time.Now() + file.UpdatedAt = time.Now() + + _, err := collection.InsertOne(context.Background(), file) + return err +} + +// Updates a file document by ID with non-empty fields and returns updated doc. 
+func updateFile(file *models.FileMetadata, id string) error { + objID, err := primitive.ObjectIDFromHex(id) + if err != nil { + return fmt.Errorf("invalid ID: %v", err) + } + + updateFields := bson.M{} + if file.FileName != "" { + updateFields["filename"] = file.FileName + } + if file.FilePath != "" { + updateFields["filepath"] = file.FilePath + } + if file.FileSize != 0 { + updateFields["filesize"] = file.FileSize + } + if file.ContentType != "" { + updateFields["content_type"] = file.ContentType + } + updateFields["updated_at"] = time.Now() + + _, err = collection.UpdateOne(context.Background(), bson.M{"_id": objID}, bson.M{"$set": updateFields}) + if err != nil { + return err + } + + return collection.FindOne(context.Background(), bson.M{"_id": objID}).Decode(file) +} + +// Deletes a single file document by ID. +func deleteOneFile(fileID string) error { + objID, err := primitive.ObjectIDFromHex(fileID) + if err != nil { + return err + } + + _, err = collection.DeleteOne(context.Background(), bson.M{"_id": objID}) + return err +} + +// Retrieves all file documents from the collection. +func getAllFiles() ([]models.FileMetadata, error) { + cursor, err := collection.Find(context.Background(), bson.M{}) + if err != nil { + return nil, err + } + defer cursor.Close(context.Background()) + + var files []models.FileMetadata + for cursor.Next(context.Background()) { + var file models.FileMetadata + if err := cursor.Decode(&file); err != nil { + log.Println("Decode error:", err) + continue + } + files = append(files, file) + } + + return files, nil +} + +// Retrieves a single file document by ID. +func getFileByID(fileID string) (*models.FileMetadata, error) { + objID, err := primitive.ObjectIDFromHex(fileID) + if err != nil { + return nil, err + } + + var file models.FileMetadata + err = collection.FindOne(context.Background(), bson.M{"_id": objID}).Decode(&file) + return &file, err +} + +// HTTP: Returns all files. 
+func GetAllFiles(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + files, err := getAllFiles() + if err != nil { + http.Error(w, "Failed to fetch files", http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(files) +} + +// HTTP: Returns a file by ID. +func GetFile(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + id := mux.Vars(r)["id"] + + file, err := getFileByID(id) + if err != nil { + http.Error(w, "File not found", http.StatusNotFound) + return + } + + json.NewEncoder(w).Encode(file) +} + +// HTTP: Creates a new file. +func CreateFile(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var file models.FileMetadata + if err := json.NewDecoder(r.Body).Decode(&file); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + if err := createFile(&file); err != nil { + http.Error(w, "Failed to create file", http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(file) +} + +// HTTP: Updates a file by ID. +func UpdateFile(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + id := mux.Vars(r)["id"] + + var file models.FileMetadata + if err := json.NewDecoder(r.Body).Decode(&file); err != nil { + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + if err := updateFile(&file, id); err != nil { + http.Error(w, "Failed to update file", http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(file) +} + +// HTTP: Deletes a file by ID. 
+func DeleteFile(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + id := mux.Vars(r)["id"] + + if err := deleteOneFile(id); err != nil { + http.Error(w, "Failed to delete file", http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(map[string]string{"deleted_id": id}) +} diff --git a/tasks/crud/go.mod b/tasks/crud/go.mod new file mode 100644 index 0000000..c0f262b --- /dev/null +++ b/tasks/crud/go.mod @@ -0,0 +1,22 @@ +module crud-api + +go 1.22.2 + +require ( + github.com/gorilla/mux v1.8.1 + github.com/joho/godotenv v1.5.1 + go.mongodb.org/mongo-driver v1.17.4 +) + +require ( + github.com/golang/snappy v1.0.0 // indirect + github.com/klauspost/compress v1.16.7 // indirect + github.com/montanaflynn/stats v0.7.1 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect + golang.org/x/crypto v0.33.0 // indirect + golang.org/x/sync v0.11.0 // indirect + golang.org/x/text v0.22.0 // indirect +) diff --git a/tasks/crud/go.sum b/tasks/crud/go.sum new file mode 100644 index 0000000..2770696 --- /dev/null +++ b/tasks/crud/go.sum @@ -0,0 +1,54 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/joho/godotenv 
v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= +github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mongodb.org/mongo-driver v1.17.4 h1:jUorfmVzljjr0FLzYQsGP8cgN/qzzxlY9Vh0C9KFXVw= +go.mongodb.org/mongo-driver v1.17.4/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod 
h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= 
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/tasks/crud/main b/tasks/crud/main new file mode 100755 index 0000000..b03c2cb Binary files /dev/null and b/tasks/crud/main differ diff --git a/tasks/crud/main.go b/tasks/crud/main.go new file mode 100644 index 0000000..bd67543 --- /dev/null +++ b/tasks/crud/main.go @@ -0,0 +1,21 @@ +package main + +import ( + "crud-api/router" + "fmt" + "net/http" + "os" +) + +func main() { + port := os.Getenv("PORT") + if port == "" { + port = "8000" + } + + fmt.Println("Server running at http://localhost:" + port) + err := http.ListenAndServe(":"+port, router.Router()) + if err != nil { + fmt.Println("Failed to start server:", err) + } +} diff --git a/tasks/crud/router/router.go b/tasks/crud/router/router.go new file mode 100644 index 0000000..add9a43 --- /dev/null +++ b/tasks/crud/router/router.go @@ -0,0 +1,20 @@ +package router + +import ( + "crud-api/controllers" + + "github.com/gorilla/mux" +) + +func Router() *mux.Router { + router := mux.NewRouter() + + // File API routes + router.HandleFunc("/api/files", controllers.GetAllFiles).Methods("GET") // Get all files + router.HandleFunc("/api/files/{id}", controllers.GetFile).Methods("GET") // Get single file by ID + router.HandleFunc("/api/files", controllers.CreateFile).Methods("POST") // Create a new file + router.HandleFunc("/api/files/{id}", controllers.UpdateFile).Methods("PUT") // Update a file by ID + router.HandleFunc("/api/files/{id}", controllers.DeleteFile).Methods("DELETE") // Delete a file by ID\ + + return router +} diff --git a/tasks/filenest-xs/README.md b/tasks/filenest-xs/README.md new file 
Indexes files from a directory, generates demo embeddings (random vectors), assigns each file to the nearest D1TV centroid by cosine similarity, and stores metadata and embeddings in PostgreSQL.
Run the Indexer + +```sh +go run main.go -dir=./sample_files -workers=5 +``` + +- `-dir`: Directory containing files (default: `./sample_files`) +- `-workers`: Number of concurrent workers (default: 5) + +--- + +## 🗃️ Database Schema + +Table: `file_index` + +| Column | Type | Description | +|-----------|--------------|------------------------------------| +| id | SERIAL | Primary key | +| file_name | TEXT | File name (unique with path) | +| file_path | TEXT | Full file path (unique with name) | +| embedding | float8[] | 128-dim embedding vector | +| d1tv_id | INTEGER | Assigned D1TV centroid | +| indexed_at| TIMESTAMP | Time of indexing | + +--- + +## 🧩 How It Works + +1. **main.go**: Reads files, spawns workers, orchestrates the pipeline. +2. **d1tv/d1tv.go**: Generates random embeddings, assigns to nearest centroid. +3. **database/db.go**: Handles DB connection, migration, and upsert. +4. **model/model.go**: GORM model for file metadata and embeddings. + +--- + +## 📝 License + +MIT License. See [LICENSE](../../LICENSE) for details. + +--- + +## 🙏 Acknowledgments + +Inspired by the FileNest project (Summer RAID 2025, IIT Jodhpur). \ No newline at end of file diff --git a/tasks/filenest-xs/d1tv/d1tv.go b/tasks/filenest-xs/d1tv/d1tv.go new file mode 100644 index 0000000..88f47bc --- /dev/null +++ b/tasks/filenest-xs/d1tv/d1tv.go @@ -0,0 +1,69 @@ +package d1tv + +import ( + "math" + "math/rand" +) + +const EmbeddingSize = 128 +const NumCentroids = 10 + +// D1TVCentroids holds the randomly initialized centroids for D1TV assignment. +var D1TVCentroids [NumCentroids][]float64 + +// rng is a deterministic random number generator for reproducibility. +var rng = rand.New(rand.NewSource(42)) + +// init initializes the D1TV centroids with random vectors. 
// GenerateEmbeddings creates a random normalized embedding vector.
// The content argument is currently ignored (demo stub — embeddings are random,
// not derived from the file's text).
// NOTE(review): the package-level rng is shared and math/rand.Rand is not
// goroutine-safe; this is called concurrently by worker goroutines — confirm
// and guard with a mutex or use a per-call source.
os.Getenv("POSTGRES_DB_NAME"), + ) + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + sqlDB, err := db.DB() + if err != nil { + return nil, fmt.Errorf("failed to get generic DB object: %w", err) + } + + if err := sqlDB.Ping(); err != nil { + return nil, fmt.Errorf("database ping failed: %w", err) + } + + fmt.Println("Database connection successful!") + return db, nil +} + +// UpsertFileIndex inserts or updates a FileIndex record in the database. +func UpsertFileIndex(db *gorm.DB, fileIndex *model.FileIndex) error { + result := db.Clauses( + clause.OnConflict{ + Columns: []clause.Column{ + {Name: "file_name"}, + {Name: "file_path"}, + }, + DoUpdates: clause.Assignments(map[string]interface{}{ + "embedding": fileIndex.Embedding, + "d1tv_id": fileIndex.D1TVID, + "indexed_at": fileIndex.IndexedAt, + }), + }, + ).Create(fileIndex) + + if result.Error != nil { + return result.Error + } + + return nil +} diff --git a/tasks/filenest-xs/go.mod b/tasks/filenest-xs/go.mod new file mode 100644 index 0000000..ba996dd --- /dev/null +++ b/tasks/filenest-xs/go.mod @@ -0,0 +1,27 @@ +module filenest-xs + +go 1.23.0 + +toolchain go1.23.10 + +require ( + github.com/jackc/pgx/v5 v5.7.5 // indirect + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.30.0 +) + +require ( + github.com/joho/godotenv v1.5.1 + github.com/lib/pq v1.10.9 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + golang.org/x/crypto v0.37.0 // indirect + golang.org/x/sync v0.13.0 // indirect + golang.org/x/text v0.24.0 // indirect +) diff --git a/tasks/filenest-xs/go.sum b/tasks/filenest-xs/go.sum new file mode 100644 index 0000000..557c10e --- /dev/null +++ b/tasks/filenest-xs/go.sum 
@@ -0,0 +1,40 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= +github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify 
v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= +gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= diff --git a/tasks/filenest-xs/main.go b/tasks/filenest-xs/main.go new file mode 100644 index 0000000..b26c7c1 --- /dev/null +++ b/tasks/filenest-xs/main.go @@ -0,0 +1,137 @@ +package main + +import ( + "context" + "filenest-xs/d1tv" + "filenest-xs/database" + "filenest-xs/model" + "flag" + "fmt" + "log" + "os" + "os/signal" + "path/filepath" + "sync" + "syscall" + + "gorm.io/gorm" +) + +// FileJob represents a file to be processed +type FileJob struct { + FileName string + FilePath string +} + +func main() 
{ + // Parse command-line flags + dirPath := flag.String("dir", "./sample_files", "Directory of files to process") + numWorkers := flag.Int("workers", 5, "Number of concurrent workers") + flag.Parse() + + // Setup context for graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Handle interrupt signals (e.g., Ctrl+C) + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt, syscall.SIGTERM) + go func() { + <-sigs + fmt.Println("\nShutting down...") + cancel() + }() + + // Initialize database + db, err := database.DbInit() + if err != nil { + log.Fatal(err) + } + sqlDB, err := db.DB() + if err != nil { + log.Fatal(err) + } + defer sqlDB.Close() + + // Read files from directory + entries, err := os.ReadDir(*dirPath) + if err != nil { + log.Fatalf("Failed to read directory: %v", err) + } + + // Start worker pool + jobs := make(chan FileJob) + var wg sync.WaitGroup + for i := 0; i < *numWorkers; i++ { + wg.Add(1) + go startWorker(ctx, db, jobs, &wg, i+1) // Pass worker ID (1-based) + } + + // Send jobs to workers +LoopToBreak: + for _, entry := range entries { + if entry.IsDir() { + continue + } + select { + case <-ctx.Done(): + break LoopToBreak + case jobs <- FileJob{ + FileName: entry.Name(), + FilePath: filepath.Join(*dirPath, entry.Name()), + }: + } + } + + close(jobs) + wg.Wait() + + if ctx.Err() != nil { + fmt.Println("Processing interrupted before completion.") + } else { + fmt.Println("All files processed.") + } +} + +// startWorker processes jobs from the channel +func startWorker(ctx context.Context, db *gorm.DB, jobs <-chan FileJob, wg *sync.WaitGroup, workerID int) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case job, ok := <-jobs: + if !ok { + return + } + processFile(job, db, workerID) + } + } +} + +// processFile reads, embeds, assigns, and saves file info +func processFile(job FileJob, db *gorm.DB, workerID int) { + content, err := os.ReadFile(job.FilePath) + if err 
!= nil { + log.Printf("[Worker-%d] Error reading %s: %v", workerID, job.FilePath, err) + return + } + + vec := d1tv.GenerateEmbeddings(string(content)) + treeID, similarity := d1tv.AssignToD1TV(vec) + + // Log in requested format + fmt.Printf("[Worker-%d] Processed file: %s → D1TV: %d (similarity: %.2f)\n", + workerID, job.FileName, treeID, similarity) + + fileIndex := &model.FileIndex{ + FileName: job.FileName, + FilePath: job.FilePath, + Embedding: vec, + D1TVID: treeID, + } + + if err := database.UpsertFileIndex(db, fileIndex); err != nil { + log.Printf("[Worker-%d] DB error for %s: %v", workerID, job.FileName, err) + } +} diff --git a/tasks/filenest-xs/model/model.go b/tasks/filenest-xs/model/model.go new file mode 100644 index 0000000..52042ce --- /dev/null +++ b/tasks/filenest-xs/model/model.go @@ -0,0 +1,20 @@ +package model + +import ( + "time" + + "github.com/lib/pq" +) + +type FileIndex struct { + ID uint `gorm:"primaryKey"` + FileName string + FilePath string + Embedding pq.Float64Array `gorm:"type:float8[]"` + D1TVID int `gorm:"column:d1tv_id"` + IndexedAt time.Time `gorm:"autoCreateTime"` +} + +func (FileIndex) TableName() string { + return "file_index" +} diff --git a/tasks/filenest-xs/sample_files/e1.txt b/tasks/filenest-xs/sample_files/e1.txt new file mode 100644 index 0000000..6a3c340 --- /dev/null +++ b/tasks/filenest-xs/sample_files/e1.txt @@ -0,0 +1 @@ +example 1 \ No newline at end of file diff --git a/tasks/filenest-xs/sample_files/e2.txt b/tasks/filenest-xs/sample_files/e2.txt new file mode 100644 index 0000000..b1b8add --- /dev/null +++ b/tasks/filenest-xs/sample_files/e2.txt @@ -0,0 +1 @@ +example 2 \ No newline at end of file diff --git a/tasks/filenest-xs/sample_files/e3.txt b/tasks/filenest-xs/sample_files/e3.txt new file mode 100644 index 0000000..1e2ce4d --- /dev/null +++ b/tasks/filenest-xs/sample_files/e3.txt @@ -0,0 +1 @@ +example 3 \ No newline at end of file diff --git a/tasks/filenest-xs/sample_files/e4.txt 
b/tasks/filenest-xs/sample_files/e4.txt new file mode 100644 index 0000000..bff3701 --- /dev/null +++ b/tasks/filenest-xs/sample_files/e4.txt @@ -0,0 +1 @@ +example 4 diff --git a/tasks/filenest-xs/sample_files/e5.txt b/tasks/filenest-xs/sample_files/e5.txt new file mode 100644 index 0000000..8255417 --- /dev/null +++ b/tasks/filenest-xs/sample_files/e5.txt @@ -0,0 +1 @@ +example 5 diff --git a/tasks/image_classifier/Dockerfile b/tasks/image_classifier/Dockerfile new file mode 100644 index 0000000..5c48651 --- /dev/null +++ b/tasks/image_classifier/Dockerfile @@ -0,0 +1,18 @@ +# Base image with PyTorch + CUDA support +FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime + +# Set working directory inside container +WORKDIR /usr/src/app/image_classifier + +# Copy contents from the image_classifier directory on host into container +COPY ./tasks/image_classifier/ . + +# Upgrade pip and install Python dependencies +RUN pip install --upgrade pip && \ + pip install -r requirements.txt + +# Optional: Prevent Ray memory monitor errors in Docker +ENV RAY_memory_monitor_refresh_ms=0 + +# Default command +CMD ["python", "main.py"] diff --git a/tasks/image_classifier/README.md b/tasks/image_classifier/README.md new file mode 100644 index 0000000..4696ebc --- /dev/null +++ b/tasks/image_classifier/README.md @@ -0,0 +1,94 @@ +# Image Classifier + +A PyTorch-based image classifier for MNIST, featuring experiment tracking, hyperparameter tuning, configuration management, and Docker support. 
+ +## Features + +- MobileNetV2-inspired model in PyTorch +- Data loading and preprocessing with torchvision +- Model summary with torchinfo +- Experiment tracking with Weights & Biases (wandb) +- Hyperparameter tuning with Ray Tune +- Configuration management using Hydra and OmegaConf +- **Docker support** for reproducible environments + +## Project Structure + +``` +tasks/image_classifier/ +├── assets/ # Screenshots and visualizations +├── conf/ +│ └── config.yaml # Training and experiment config +├── data.py # Data loaders for MNIST +├── main.py # Entry point, Ray Tune + Hydra integration +├── model.py # Model definition +├── model.ipynb # Jupyter notebook for model exploration +├── requirements.txt # Python dependencies +├── train.py # Training and validation logic +├── Dockerfile # Docker container specification +└── README.md # Project documentation +``` + +## Installation + +1. Clone the repository. +2. Install dependencies: + ```bash + pip install -r requirements.txt + ``` + +## Usage + +Run the main experiment (with Ray Tune and Hydra): +```bash +python main.py +``` + +You can modify hyperparameters and settings in [`conf/config.yaml`](conf/config.yaml). + +## Configuration + +- All training and experiment settings are managed via Hydra config files in the `conf/` directory. +- You can override config values from the command line, e.g.: + ```bash + python main.py lr=0.01 batch_size=32 + ``` + +## Docker + +Build and run the project in a Docker container: + +```bash +docker build -t image-classifier . +docker run --gpus all --rm image-classifier +``` + +- The Dockerfile uses the official PyTorch image with CUDA support. +- All dependencies are installed inside the container. + +## Dependencies + +- torch +- torchvision +- torchinfo +- wandb +- ray[tune] +- hydra-core +- omegaconf + +## Results & Metrics + +After training, you can view metrics and compare runs in your [Weights & Biases dashboard](https://wandb.ai/). 
+ +### 📊 Metrics Screenshots + +Paste your screenshots into the `assets/` folder and reference them below: + +- **WandB Metrics:** + ![WandB Metrics Screenshot](./assets/wandb_metrics.png) + +- **Model Summary:** + ![Model Summary](./assets/model_summary.png) + +- **Ray Tune Metrics:** + ![Ray Tune Metrics](./assets/metrics.png) diff --git a/tasks/image_classifier/assets/metrics.png b/tasks/image_classifier/assets/metrics.png new file mode 100644 index 0000000..42404f2 Binary files /dev/null and b/tasks/image_classifier/assets/metrics.png differ diff --git a/tasks/image_classifier/assets/model_summary.png b/tasks/image_classifier/assets/model_summary.png new file mode 100644 index 0000000..3e70e23 Binary files /dev/null and b/tasks/image_classifier/assets/model_summary.png differ diff --git a/tasks/image_classifier/assets/wandb_metrics.png b/tasks/image_classifier/assets/wandb_metrics.png new file mode 100644 index 0000000..d5caa43 Binary files /dev/null and b/tasks/image_classifier/assets/wandb_metrics.png differ diff --git a/tasks/image_classifier/conf/config.yaml b/tasks/image_classifier/conf/config.yaml new file mode 100644 index 0000000..d779e3b --- /dev/null +++ b/tasks/image_classifier/conf/config.yaml @@ -0,0 +1,11 @@ +# Training configuration +lr: 0.001 +batch_size: 64 +max_num_epochs: 10 +num_trials: 2 +device: cuda # Automatically use CUDA; you can also use: ${oc.env:CUDA_VISIBLE_DEVICES, "cpu"} + + +# Weights & Biases configuration +wandb_project: "Img_classifier" +wandb_mode: "online" \ No newline at end of file diff --git a/tasks/image_classifier/data.py b/tasks/image_classifier/data.py new file mode 100644 index 0000000..840971b --- /dev/null +++ b/tasks/image_classifier/data.py @@ -0,0 +1,22 @@ +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + + +def mnist_train_loader(batch_size): + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + train_dataset = 
datasets.MNIST( + "./data", train=True, download=True, transform=transform + ) + return DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + +def mnist_test_loader(batch_size): + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + test_dataset = datasets.MNIST( + "./data", train=False, download=True, transform=transform + ) + return DataLoader(test_dataset, batch_size=batch_size, shuffle=False) diff --git a/tasks/image_classifier/main.py b/tasks/image_classifier/main.py new file mode 100644 index 0000000..0aa94ea --- /dev/null +++ b/tasks/image_classifier/main.py @@ -0,0 +1,59 @@ +import hydra +from omegaconf import DictConfig +from ray import tune +from ray.tune import Tuner, TuneConfig, RunConfig + +from data import mnist_train_loader, mnist_test_loader +from train import train_mobilenet_tune + + +# Helper to add train/test DataLoader to config (not used by Ray Tune, but can be handy) +def with_data_loader(cfg): + cfg = cfg.copy() + cfg["train_loader"] = mnist_train_loader(batch_size=cfg["batch_size"]) + cfg["test_loader"] = mnist_test_loader(batch_size=cfg["batch_size"]) + return cfg + +search_space = { + "lr": tune.loguniform(1e-4, 1e-2), # Learning rate sampled log-uniformly between 0.0001 and 0.01 + "batch_size": tune.choice([64, 128]), # Batch size sampled from 64 or 128 + "max_num_epochs": 15, # Fixed number of epochs + "num_trials": 2, # Number of Ray Tune trials (not part of search space per trial) + "device": "cuda", # Device setting + "optim": "Adam", # Optimizer choice (fixed as Adam) + "wandb_project": "mobilenetv2-mnist", # WandB project name + "wandb_mode": "online", # WandB logging mode +} +# Main entry point, managed by Hydra for config management +@hydra.main(config_path="conf", config_name="config", version_base=None) +def main(cfg: DictConfig): + # Optionally add DataLoaders to config (not strictly needed for Ray Tune) + cfg = with_data_loader(dict(cfg)) + + # Set up Ray Tune 
Tuner for hyperparameter search + tuner = Tuner( + tune.with_resources( + tune.with_parameters(train_mobilenet_tune), + resources={"cpu": 2, "gpu": 1 if cfg["device"] == "cuda" else 0}, + ), + tune_config=TuneConfig( + metric="loss", # Optimize for minimum validation loss + mode="min", + num_samples=cfg["num_trials"], # Number of Ray Tune trials + ), + run_config=RunConfig(name="mobilenet_mnist"), + param_space=search_space, # Pass config as search space + ) + + # Run hyperparameter search + results = tuner.fit() + best_result = results.get_best_result("loss", "min") + + # Print best trial configuration and final metrics + print(f"\n Best trial config:\n{best_result.config}") + print(f"Final val accuracy: {best_result.metrics['accuracy']:.4f}") + print(f"Final val loss: {best_result.metrics['loss']:.4f}") + + +if __name__ == "__main__": + main() diff --git a/tasks/image_classifier/model.ipynb b/tasks/image_classifier/model.ipynb new file mode 100644 index 0000000..6ce368e --- /dev/null +++ b/tasks/image_classifier/model.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "9dc0632e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torchinfo as info\n", + "import data\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "seed = 42\n", + "torch.cuda.manual_seed(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "1dcc61e7", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "class BottleneckLayer(nn.Module):\n", + " def __init__(self, in_c, out_c, exp_f, stride=1):\n", + " super().__init__()\n", + "\n", + " self.use_res_connect = (stride == 1 and in_c == out_c)\n", + " mid_c = in_c * exp_f\n", + "\n", + " self.block = nn.Sequential(\n", + "\n", + " nn.Conv2d(in_c, mid_c, kernel_size=1, bias=False),\n", + " nn.BatchNorm2d(mid_c),\n", 
+ " nn.ReLU6(inplace=True),\n", + "\n", + "\n", + " nn.Conv2d(mid_c, mid_c, kernel_size=3, stride=stride,\n", + " padding=1, groups=mid_c, bias=False),\n", + " nn.BatchNorm2d(mid_c),\n", + " nn.ReLU6(inplace=True),\n", + "\n", + " # 1x1 projection\n", + " nn.Conv2d(mid_c, out_c, kernel_size=1, bias=False),\n", + " nn.BatchNorm2d(out_c),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " if self.use_res_connect:\n", + " return x + self.block(x)\n", + " else:\n", + " return self.block(x)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "312d4caf", + "metadata": {}, + "outputs": [], + "source": [ + "class Model(nn.Module):\n", + " def __init__(self, in_c=1, num_classes=10):\n", + " super().__init__()\n", + "\n", + " self.conv1 = nn.Conv2d(in_c, 8, kernel_size=3, stride=1, padding=1) # (1 → 4)\n", + "\n", + " self.block1 = BottleneckLayer(8, 16, exp_f=2, stride=2) # (4 → 8)\n", + " self.block2 = BottleneckLayer(16, 16, exp_f=2, stride=1) # (8 → 8)\n", + " self.block3 = BottleneckLayer(16, 32, exp_f=2, stride=2) # (8 → 16)\n", + " self.block4 = BottleneckLayer(32, 32, exp_f=2, stride=1) # (16 → 16)\n", + "\n", + " self.pool = nn.AdaptiveAvgPool2d((1, 1)) # (B, 16, 1, 1)\n", + " self.fc = nn.Linear(32, num_classes, bias=False)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.conv1(x))\n", + " x = self.block1(x)\n", + " x = self.block2(x)\n", + " x = self.block3(x)\n", + " x = self.block4(x)\n", + " x = self.pool(x)\n", + " x = x.view(x.size(0), -1)\n", + " return self.fc(x)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "4421b015", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=================================================================\n", + "Layer (type:depth-idx) Param #\n", + "=================================================================\n", + "Model --\n", + "├─Conv2d: 1-1 80\n", + "├─BottleneckLayer: 1-2 --\n", + "│ └─Sequential: 2-1 
--\n", + "│ │ └─Conv2d: 3-1 128\n", + "│ │ └─BatchNorm2d: 3-2 32\n", + "│ │ └─ReLU6: 3-3 --\n", + "│ │ └─Conv2d: 3-4 144\n", + "│ │ └─BatchNorm2d: 3-5 32\n", + "│ │ └─ReLU6: 3-6 --\n", + "│ │ └─Conv2d: 3-7 256\n", + "│ │ └─BatchNorm2d: 3-8 32\n", + "├─BottleneckLayer: 1-3 --\n", + "│ └─Sequential: 2-2 --\n", + "│ │ └─Conv2d: 3-9 512\n", + "│ │ └─BatchNorm2d: 3-10 64\n", + "│ │ └─ReLU6: 3-11 --\n", + "│ │ └─Conv2d: 3-12 288\n", + "│ │ └─BatchNorm2d: 3-13 64\n", + "│ │ └─ReLU6: 3-14 --\n", + "│ │ └─Conv2d: 3-15 512\n", + "│ │ └─BatchNorm2d: 3-16 32\n", + "├─BottleneckLayer: 1-4 --\n", + "│ └─Sequential: 2-3 --\n", + "│ │ └─Conv2d: 3-17 512\n", + "│ │ └─BatchNorm2d: 3-18 64\n", + "│ │ └─ReLU6: 3-19 --\n", + "│ │ └─Conv2d: 3-20 288\n", + "│ │ └─BatchNorm2d: 3-21 64\n", + "│ │ └─ReLU6: 3-22 --\n", + "│ │ └─Conv2d: 3-23 1,024\n", + "│ │ └─BatchNorm2d: 3-24 64\n", + "├─BottleneckLayer: 1-5 --\n", + "│ └─Sequential: 2-4 --\n", + "│ │ └─Conv2d: 3-25 2,048\n", + "│ │ └─BatchNorm2d: 3-26 128\n", + "│ │ └─ReLU6: 3-27 --\n", + "│ │ └─Conv2d: 3-28 576\n", + "│ │ └─BatchNorm2d: 3-29 128\n", + "│ │ └─ReLU6: 3-30 --\n", + "│ │ └─Conv2d: 3-31 2,048\n", + "│ │ └─BatchNorm2d: 3-32 64\n", + "├─AdaptiveAvgPool2d: 1-6 --\n", + "├─Linear: 1-7 320\n", + "=================================================================\n", + "Total params: 9,504\n", + "Trainable params: 9,504\n", + "Non-trainable params: 0\n", + "=================================================================\n" + ] + } + ], + "source": [ + "model = Model().to(device=device)\n", + "print(info.summary(model))" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "7b3cba56", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optim = torch.optim.Adam(model.parameters() , lr= 0.001)\n", + "epochs = 20" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "76eb88de", + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, train_loader, 
test_loader, criterion, optimizer, device, epochs=10):\n", + " model.to(device)\n", + "\n", + " for epoch in range(epochs):\n", + " model.train()\n", + " train_correct = 0\n", + " train_total = 0\n", + " train_loss = 0.0\n", + "\n", + " for xb, yb in train_loader:\n", + " xb, yb = xb.to(device), yb.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(xb)\n", + " loss = criterion(outputs, yb)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " train_loss += loss.item() * xb.size(0)\n", + " _, predicted = torch.max(outputs, 1)\n", + " train_correct += (predicted == yb).sum().item()\n", + " train_total += yb.size(0)\n", + "\n", + " train_acc = 100 * train_correct / train_total\n", + " avg_train_loss = train_loss / train_total\n", + "\n", + " # --- Evaluation on Test Set ---\n", + " model.eval()\n", + " test_correct = 0\n", + " test_total = 0\n", + " with torch.no_grad():\n", + " for xb, yb in test_loader:\n", + " xb, yb = xb.to(device), yb.to(device)\n", + " outputs = model(xb)\n", + " _, predicted = torch.max(outputs, 1)\n", + " test_correct += (predicted == yb).sum().item()\n", + " test_total += yb.size(0)\n", + "\n", + " test_acc = 100 * test_correct / test_total\n", + "\n", + " print(f\"Epoch {epoch+1}/{epochs} | Loss: {avg_train_loss:.4f} | \"\n", + " f\"Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "43f829b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10 | Loss: 0.4815 | Train Acc: 86.17% | Test Acc: 95.91%\n", + "Epoch 2/10 | Loss: 0.1043 | Train Acc: 96.97% | Test Acc: 97.90%\n", + "Epoch 3/10 | Loss: 0.0734 | Train Acc: 97.79% | Test Acc: 98.17%\n", + "Epoch 4/10 | Loss: 0.0592 | Train Acc: 98.20% | Test Acc: 98.27%\n", + "Epoch 5/10 | Loss: 0.0500 | Train Acc: 98.47% | Test Acc: 98.40%\n", + "Epoch 6/10 | Loss: 0.0457 | Train Acc: 98.58% | Test Acc: 98.44%\n", + "Epoch 7/10 | 
Loss: 0.0418 | Train Acc: 98.69% | Test Acc: 98.51%\n", + "Epoch 8/10 | Loss: 0.0363 | Train Acc: 98.83% | Test Acc: 98.74%\n", + "Epoch 9/10 | Loss: 0.0353 | Train Acc: 98.85% | Test Acc: 98.77%\n", + "Epoch 10/10 | Loss: 0.0340 | Train Acc: 98.86% | Test Acc: 98.68%\n" + ] + } + ], + "source": [ + "train(model , data.train_loader , data.test_loader , criterion , optim , device , epochs=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd63572c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tasks/image_classifier/model.py b/tasks/image_classifier/model.py new file mode 100644 index 0000000..4ef3b83 --- /dev/null +++ b/tasks/image_classifier/model.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchinfo as info + +import data # assumed to contain DataLoader setup + +# Select device: GPU if available, else CPU +device = "cuda" if torch.cuda.is_available() else "cpu" + +# Set seed for reproducibility +seed = 42 +torch.cuda.manual_seed(seed) # sets seed for current GPU context only + + +class BottleneckLayer(nn.Module): + def __init__(self, in_c, out_c, exp_f, stride=1): + super().__init__() + + # Determine whether residual connection is possible + self.use_res_connect = stride == 1 and in_c == out_c + + # Intermediate (expanded) channel size + mid_c = in_c * exp_f + + # Bottleneck block: Expand → Depthwise → Project + self.block = nn.Sequential( + # 1x1 Convolution (Expansion phase) + nn.Conv2d(in_c, mid_c, kernel_size=1, bias=False), + nn.BatchNorm2d(mid_c), 
+ nn.ReLU6(inplace=True), + # 3x3 Depthwise Convolution + nn.Conv2d( + mid_c, + mid_c, + kernel_size=3, + stride=stride, + padding=1, + groups=mid_c, + bias=False, + ), + nn.BatchNorm2d(mid_c), + nn.ReLU6(inplace=True), + # 1x1 Convolution (Projection phase) + nn.Conv2d(mid_c, out_c, kernel_size=1, bias=False), + nn.BatchNorm2d(out_c), + ) + + def forward(self, x): + # Apply residual connection if allowed + if self.use_res_connect: + return x + self.block(x) + else: + return self.block(x) + + +class Model(nn.Module): + def __init__(self, in_c=1, num_classes=10): + super().__init__() + + # Initial standard convolution: 1×28×28 → 8×28×28 + self.conv1 = nn.Conv2d(in_c, 8, kernel_size=3, stride=1, padding=1) + + # Stack of Bottleneck Layers (MobileNetV2 style) + self.block1 = BottleneckLayer(8, 16, exp_f=2, stride=2) # 8×28×28 → 16×14×14 + self.block2 = BottleneckLayer(16, 16, exp_f=2, stride=1) # 16×14×14 → 16×14×14 + self.block3 = BottleneckLayer(16, 32, exp_f=2, stride=2) # 16×14×14 → 32×7×7 + self.block4 = BottleneckLayer(32, 32, exp_f=2, stride=1) # 32×7×7 → 32×7×7 + + # Global average pooling: 32×7×7 → 32×1×1 + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + + # Final linear classifier (flattened from 32) + self.fc = nn.Linear(32, num_classes, bias=False) + + def forward(self, x): + # Forward pass through each layer + x = F.relu(self.conv1(x)) + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.pool(x) # Shape: (B, 32, 1, 1) + x = x.view(x.size(0), -1) # Flatten to (B, 32) + return self.fc(x) # Final logits (B, 10) + + +# Instantiate model and move to device +model = Model().to(device=device) + +# Print model architecture and parameter summary +print(info.summary(model)) + +# Define loss function and optimizer +criterion = nn.CrossEntropyLoss() +optim = torch.optim.Adam(model.parameters(), lr=0.001) diff --git a/tasks/image_classifier/requirements.txt b/tasks/image_classifier/requirements.txt new file mode 100644 index 
0000000..cf6bb15 --- /dev/null +++ b/tasks/image_classifier/requirements.txt @@ -0,0 +1,7 @@ +torch>=2.0 +torchvision>=0.15 +torchinfo>=1.8 +wandb>=0.15 +ray[tune]>=2.0 +hydra-core>=1.3 +omegaconf>=2.3 \ No newline at end of file diff --git a/tasks/image_classifier/train.py b/tasks/image_classifier/train.py new file mode 100644 index 0000000..8fd852d --- /dev/null +++ b/tasks/image_classifier/train.py @@ -0,0 +1,103 @@ +import os +import tempfile + +import torch +import torch.nn as nn +import torch.optim as optim +import wandb +from ray import tune + +from data import mnist_test_loader, mnist_train_loader +from model import Model + + +def train_mobilenet_tune(config): + # Instantiate the model and move to the specified device + model = Model().to(config["device"]) + + # Use DataParallel if multiple GPUs are available + if config["device"] == "cuda" and torch.cuda.device_count() > 1: + model = nn.DataParallel(model) + + # Select optimizer based on config + optimizer = optim.Adam(model.parameters(),lr = config["lr"]) + + criterion = nn.CrossEntropyLoss() + + # Restore from checkpoint if available (for Ray Tune) + if tune.get_checkpoint(): + with tune.get_checkpoint().as_directory() as ckpt_dir: + model_state, opt_state = torch.load(os.path.join(ckpt_dir, "checkpoint.pt")) + model.load_state_dict(model_state) + optimizer.load_state_dict(opt_state) + + # Prepare data loaders + train_loader = mnist_train_loader(batch_size=config["batch_size"]) + test_loader = mnist_test_loader(batch_size=config["batch_size"]) + + # Initialize Weights & Biases logging + wandb.init( + project=config.get("wandb_project", "mnist-raytune"), + config=config, + mode=config.get("wandb_mode", "online"), + reinit=True, + ) + + for epoch in range(config["max_num_epochs"]): + model.train() + total_loss, correct, total = 0.0, 0, 0 + + # Training loop + for xb, yb in train_loader: + xb, yb = xb.to(config["device"]), yb.to(config["device"]) + optimizer.zero_grad() + out = model(xb) + loss = 
criterion(out, yb) + loss.backward() + optimizer.step() + + total_loss += loss.item() * xb.size(0) + correct += (out.argmax(1) == yb).sum().item() + total += yb.size(0) + + train_acc = correct / total + avg_loss = total_loss / total + + # Validation loop + model.eval() + val_correct, val_total, val_loss = 0, 0, 0.0 + with torch.no_grad(): + for xb, yb in test_loader: + xb, yb = xb.to(config["device"]), yb.to(config["device"]) + out = model(xb) + loss = criterion(out, yb) + val_loss += loss.item() + val_correct += (out.argmax(1) == yb).sum().item() + val_total += yb.size(0) + + val_acc = val_correct / val_total + val_loss /= len(test_loader) + + # Collect metrics for logging and Ray Tune reporting + metrics = { + "epoch": epoch + 1, + "train_acc": train_acc, + "train_loss": avg_loss, + "test_acc": val_acc, + "test_loss": val_loss, + "accuracy": val_acc, + "loss": val_loss, + } + + # Log metrics to wandb + wandb.log(metrics) + + # Save checkpoint and report metrics to Ray Tune + with tempfile.TemporaryDirectory() as tmpdir: + ckpt_path = os.path.join(tmpdir, "checkpoint.pt") + torch.save((model.state_dict(), optimizer.state_dict()), ckpt_path) + checkpoint = tune.Checkpoint.from_directory(tmpdir) + tune.report(metrics, checkpoint=checkpoint) + + # Finish wandb run + wandb.finish() diff --git a/tasks/udp-messenger/cmd/main.go b/tasks/udp-messenger/cmd/main.go new file mode 100644 index 0000000..4284f51 --- /dev/null +++ b/tasks/udp-messenger/cmd/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "strconv" + "udp-messenger/internal/messaging" +) + +/* +main is the entry point of the UDP messenger application. +It parses command-line flags for target and local ports/IP, +creates a UDP socket, and starts the receiver and sender routines. 
+*/ +func main() { + targetIP := flag.String("target-ip", "127.0.0.1", "Target IP address") + targetPort := flag.Int("target-port", 0, "Target port") + localPort := flag.Int("local-port", 0, "Local port to bind to") + flag.Parse() + + if *targetPort == 0 || *localPort == 0 { + fmt.Println(">>Usage: --local-port --target-port --target-ip ") + flag.PrintDefaults() + os.Exit(1) + } + // Create UDP socket bound to the specified local port. + conn, err := messaging.CreateUDPSocket(*localPort) + if err != nil { + log.Fatal(err) + } + defer conn.Close() + + // Resolve the target UDP address. + targetAddrStr := *targetIP + ":" + strconv.Itoa(*targetPort) + targetAddr, err := net.ResolveUDPAddr("udp", targetAddrStr) + if err != nil { + log.Fatal(">>Invalid target address:", err) + } + + // Start the receiver and sender goroutines. + messaging.StartReceiver(conn) + messaging.StartSender(conn, targetAddr) + + // Block forever to keep the main goroutine alive. + select {} +} diff --git a/tasks/udp-messenger/go.mod b/tasks/udp-messenger/go.mod new file mode 100644 index 0000000..cc1371e --- /dev/null +++ b/tasks/udp-messenger/go.mod @@ -0,0 +1,3 @@ +module udp-messenger + +go 1.22.2 diff --git a/tasks/udp-messenger/internal/messaging/receiver.go b/tasks/udp-messenger/internal/messaging/receiver.go new file mode 100644 index 0000000..990d177 --- /dev/null +++ b/tasks/udp-messenger/internal/messaging/receiver.go @@ -0,0 +1,24 @@ +package messaging + +import ( + "fmt" + "net" +) + +/* +StartReceiver launches a goroutine that listens for incoming UDP messages +on the provided connection and prints them to the console. 
+*/ +func StartReceiver(conn *net.UDPConn) { + go func() { + buffer := make([]byte, 1024) + for { + n, addr, err := conn.ReadFromUDP(buffer) + if err != nil { + fmt.Println("Error receiving message:", err) + continue + } + fmt.Printf("Received from %s: %s\n", addr.String(), string(buffer[:n])) + } + }() +} diff --git a/tasks/udp-messenger/internal/messaging/sender.go b/tasks/udp-messenger/internal/messaging/sender.go new file mode 100644 index 0000000..f006951 --- /dev/null +++ b/tasks/udp-messenger/internal/messaging/sender.go @@ -0,0 +1,42 @@ +package messaging + +import ( + "bufio" + "fmt" + "net" + "os" + "strings" +) + +/* +StartSender launches a goroutine that reads user input from stdin +and sends it as UDP messages to the specified target address. +Typing 'exit' will terminate the program. +*/ +func StartSender(conn *net.UDPConn, targetAddr *net.UDPAddr) { + go func() { + reader := bufio.NewReader(os.Stdin) + fmt.Println(">>Type your message and press Enter (type 'exit' to quit)") + fmt.Printf(">>Target: %s:%d\n", targetAddr.IP, targetAddr.Port) + + for { + fmt.Print(">> ") + text, err := reader.ReadString('\n') + if err != nil { + fmt.Println(">>Error reading input:", err) + continue + } + + text = strings.TrimSpace(text) + if text == "exit" { + fmt.Println(">>Exiting sender.") + os.Exit(0) + } + + _, err = conn.WriteToUDP([]byte(text), targetAddr) + if err != nil { + fmt.Println(">>Error sending message:", err) + } + } + }() +} diff --git a/tasks/udp-messenger/internal/messaging/socket.go b/tasks/udp-messenger/internal/messaging/socket.go new file mode 100644 index 0000000..2eb04ef --- /dev/null +++ b/tasks/udp-messenger/internal/messaging/socket.go @@ -0,0 +1,29 @@ +package messaging + +import ( + "fmt" + "net" + "strconv" +) + +/* +CreateUDPSocket binds to a static UDP port (provided by localPort) +and returns the UDP connection. It returns an error if the address +cannot be resolved or the socket cannot be bound. 
+*/ +func CreateUDPSocket(localPort int) (*net.UDPConn, error) { + addrStr := "127.0.0.1:" + strconv.Itoa(localPort) + fmt.Println("Binding to", addrStr) + + udpAddr, err := net.ResolveUDPAddr("udp", addrStr) + if err != nil { + return nil, fmt.Errorf(">>failed to resolve address: %w", err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, fmt.Errorf(">>failed to bind to port: %w", err) + } + + return conn, nil +}