-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcommon_helpers.py
More file actions
89 lines (72 loc) · 3.03 KB
/
common_helpers.py
File metadata and controls
89 lines (72 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Helpers function for the request generator"""
import numpy as np
from asyncflow.config.constants import Distribution
from asyncflow.schemas.common.random_variables import RVConfig
def uniform_variable_generator(rng: np.random.Generator) -> float:
"""Return U~Uniform(0, 1)."""
# rng is guaranteed to be a valid np.random.Generator due to the type signature.
return rng.random()
def poisson_variable_generator(
mean: float,
rng: np.random.Generator,
) -> float:
"""Return a Poisson-distributed integer with expectation *mean*."""
return rng.poisson(mean)
def truncated_gaussian_generator(
mean: float,
variance: float,
rng: np.random.Generator,
) -> float:
"""
Generate a Normal-distributed variable
with mean and variance
"""
value = rng.normal(mean, variance)
return max(0.0, value)
def lognormal_variable_generator(
mean: float,
variance: float,
rng: np.random.Generator,
) -> float:
"""Return a Poisson-distributed floateger with expectation *mean*."""
return rng.lognormal(mean, variance)
def exponential_variable_generator(
mean: float,
rng: np.random.Generator,
) -> float:
"""Return an exponentially-distributed float with mean *mean*."""
return float(rng.exponential(mean))
def general_sampler(random_variable: RVConfig, rng: np.random.Generator) -> float:
"""
Draw one sample from the distribution described by *random_variable*.
Only **Normal** and **Log-normal** require an explicit ``variance``.
For **Uniform**, **Poisson** and **Exponential** the mean is enough.
"""
dist = random_variable.distribution
mean = random_variable.mean
var = random_variable.variance
match dist:
# ── No extra parameters needed ──────────────────────────────────
case Distribution.UNIFORM:
# Variance is meaningless for an ad-hoc uniform [0, 1) helper.
assert var is None
return uniform_variable_generator(rng)
case Distribution.POISSON:
# λ == mean ; numpy returns ints → cast to float for consistency
assert var is None
return float(poisson_variable_generator(mean, rng))
case Distribution.EXPONENTIAL:
# β (scale) == mean ; nothing else required
assert var is None
return exponential_variable_generator(mean, rng)
# ── Distributions that *do* need a variance parameter ───────────
case Distribution.NORMAL:
assert var is not None
return truncated_gaussian_generator(mean, var, rng)
case Distribution.LOG_NORMAL:
assert var is not None
return lognormal_variable_generator(mean, var, rng)
# ── Anything else is unsupported ────────────────────────────────
case _:
msg = f"Unsupported distribution: {dist}"
raise ValueError(msg)