CSDI.predict: n_sampling_times is not used in the function #729

@Durakavalyanie

Description

1. System Info

In the CSDI `predict` method, the `n_sampling_times` argument is never used, so `predict` always returns a single prediction (one sample) regardless of the value passed.

2. Information

  • The official example scripts
  • My own created scripts

3. Reproduction

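# X_miss: array of shape (n_samples, n_steps, n_features), with NaN marking missing values
# csdi: a trained pypots.imputation.CSDI model
print(X_miss.shape)
X_miss = {"X": X_miss}
predict = csdi.predict(X_miss, n_sampling_times=999)
print(predict['imputation'].shape)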

Actual output:
(10, 100, 1)
(10, 1, 100, 1)

4. Expected behavior

Expected output:
(10, 100, 1)
(10, 999, 100, 1)

5. Your contribution

In PyPOTS/pypots/imputation/csdi/model.py, change

  results = self.model(inputs)

to

  results = self.model(inputs, n_sampling_times=n_sampling_times)
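The effect of the proposed one-line change can be illustrated with a self-contained toy. This is a hypothetical sketch, not PyPOTS code: `ToyBackbone`, `ToyCSDI`, `predict_buggy` and `shape` are invented names, and the toy only mimics the output layout described in this issue, (n_samples, n_sampling_times, n_steps, n_features). The real CSDI internals may differ.

```python
import random
from typing import Dict


class ToyBackbone:
    """Hypothetical stand-in for the CSDI backbone (not the PyPOTS class).

    Returns random "samples" laid out as
    (n_samples, n_sampling_times, n_steps, n_features).
    """

    def __call__(self, inputs: Dict[str, list], n_sampling_times: int = 1) -> Dict[str, list]:
        X = inputs["X"]  # nested list shaped (n_samples, n_steps, n_features)
        imputation = [
            [
                [[random.random() for _ in step] for step in sample]
                for _ in range(n_sampling_times)
            ]
            for sample in X
        ]
        return {"imputation": imputation}


class ToyCSDI:
    """Hypothetical wrapper mimicking the predict() call pattern from the issue."""

    def __init__(self) -> None:
        self.model = ToyBackbone()

    def predict_buggy(self, test_set: Dict[str, list], n_sampling_times: int = 1) -> Dict[str, list]:
        inputs = {"X": test_set["X"]}
        # The bug described in the issue: the argument is never forwarded,
        # so the backbone falls back to its default of one sample.
        results = self.model(inputs)
        return {"imputation": results["imputation"]}

    def predict(self, test_set: Dict[str, list], n_sampling_times: int = 1) -> Dict[str, list]:
        inputs = {"X": test_set["X"]}
        # The proposed fix: forward n_sampling_times to the backbone.
        results = self.model(inputs, n_sampling_times=n_sampling_times)
        return {"imputation": results["imputation"]}


def shape(nested) -> tuple:
    """Return the shape of a regular (non-ragged) nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)
```

With the same shapes as in the reproduction above:

```python
X_miss = [[[0.0] for _ in range(100)] for _ in range(10)]
csdi = ToyCSDI()
print(shape(X_miss))                                                            # (10, 100, 1)
print(shape(csdi.predict_buggy({"X": X_miss}, n_sampling_times=999)["imputation"]))  # (10, 1, 100, 1)
print(shape(csdi.predict({"X": X_miss}, n_sampling_times=999)["imputation"]))        # (10, 999, 100, 1)
```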

Metadata

Labels

bug (Something isn't working), completed (The issue has been completed)

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests