Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ jobs:
if [ ! -d InfiniCore ]; then
git clone git@github.com:InfiniTensor/InfiniCore.git &&
cd InfiniCore &&
git checkout f53154df00dc7005cebc49fe9080f1ea21ee1dfa
git submodule update --init &&
git checkout 3c8fb3c05036c95faa2eee42bf8fcc42775edd43
else
echo "InfiniCore already exists"
fi
Expand Down
24 changes: 18 additions & 6 deletions format.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,26 @@
def format_file(file):
file = Path(proj_path.joinpath(file))
print(file)

# Skip if file doesn't exist
if not file.exists():
print(f"Skipping: file does not exist - {file}")
return

if file.suffix in c_style_file:
run(
f"clang-format -style=file -i {file}", cwd=proj_path, shell=True, check=True
)
run(f"git add {file}", cwd=proj_path, shell=True)
try:
run(
f"clang-format -style=file -i {file}", cwd=proj_path, shell=True, check=True
)
run(f"git add {file}", cwd=proj_path, shell=True)
except Exception as e:
print(f"Error formatting file {file}: {e}")
elif file.suffix == py_file:
run(f"black {file}", cwd=proj_path, shell=True, check=True)
run(f"git add {file}", cwd=proj_path, shell=True)
try:
run(f"black {file}", cwd=proj_path, shell=True, check=True)
run(f"git add {file}", cwd=proj_path, shell=True)
except Exception as e:
print(f"Error formatting file {file}: {e}")


if len(sys.argv) == 1:
Expand Down
1 change: 1 addition & 0 deletions include/core/blob.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class BlobObj {
~BlobObj() {};

template <typename T> T getPtr() const { return reinterpret_cast<T>(ptr); }
void *getRawDataPtr() const { return ptr; }
};

} // namespace infini
Expand Down
2 changes: 1 addition & 1 deletion include/core/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class DataType {
case INFINI_DTYPE_U64:
return 8;
case INFINI_DTYPE_F8:
return 1; // 自定义 8-bit float
return 1;
case INFINI_DTYPE_F16:
return 2;
case INFINI_DTYPE_F32:
Expand Down
8 changes: 8 additions & 0 deletions include/core/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,5 +208,13 @@ class StrideExprObj : public BaseExprObj {
Stride getConstantValue() const;
};

inline std::string vecToString(const ShapeExpr &shape) {
return shape->toString();
}

inline std::string vecToString(const StrideExpr &stride) {
return stride->toString();
}

} // namespace infini
#endif // EXPR_H
22 changes: 12 additions & 10 deletions include/core/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,33 @@ struct ContextObj {
infiniDevice_t device = INFINI_DEVICE_CPU;
int deviceId = 0;
infinirtStream_t stream = nullptr;
void *workspace = nullptr;
size_t workspaceSize = 0;
};
using Context = Ref<ContextObj>;

class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
private:
// 全局 map: thread_id -> Context
mutable std::unordered_map<std::thread::id, Context> threadContexts;
mutable std::shared_mutex ctx_mutex;
static thread_local Context tls_context_cache;
size_t workspaceSize;
void *workspace;
static thread_local std::thread::id tls_thread_id;

public:
RuntimeObj() { allocworkspace(); }
RuntimeObj() = default;
RuntimeObj(const RuntimeObj &) = delete;
RuntimeObj &operator=(const RuntimeObj &) = delete;
~RuntimeObj();

// 每个线程唯一的 Runtime
// Unique Runtime per thread
static Runtime &getInstance();

// 每个线程初始化自己的 Context
// Initialize each thread's own Context
void initThreadContext(infiniDevice_t device, int deviceId = 0);

// 获取活跃 Context
// Get active Context
Context getCurrentThreadContext() const;
// Switch device for current thread
void setCurrentDevice(infiniDevice_t device, int deviceId = 0);

static void init();
Expand All @@ -56,15 +58,15 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
infinirtMemcpyKind_t kind, infinirtStream_t stream);
void *mallocAsync(size_t size, infinirtStream_t stream);
void freeAsync(void *ptr, infinirtStream_t stream);
// Synchronize device for current thread
void synchronize() const;
// Get workspace of current Context
size_t getWorkspaceSize() const;
void *getWorkspace(size_t size) const;

bool isCpu() const;

// string toString() const;
private:
void allocworkspace();
// void initWorkspace(size_t size = 7ll << 30);
};
} // namespace infini
#endif // RUNTIME_H
Loading